Compare commits

...

58 Commits

Author SHA1 Message Date
f2a17553d5 Added CODEOWNERS 2025-06-25 10:07:36 -07:00
889ff20648 v4.0 update v2. (#2420)
* Ex77 forward kernel fix.
2025-06-25 12:56:25 -04:00
dc4817921e v4.0 update. (#2398)
* Ex77 fix.
2025-06-12 09:10:29 -04:00
5c6bca0441 Update requirements.txt (#2390)
Remove the dev suffix in the wheel version
2025-06-10 02:31:49 -04:00
c2ad7c5b20 fix link in readme (#2379) 2025-06-07 07:38:38 -04:00
cc23f6d1e9 fix link (#2377) 2025-06-07 06:00:39 -04:00
5a287538c2 "Update CHANGELOG for 4.0 tagging" (#2374) 2025-06-06 10:07:36 -04:00
8bdbfca682 v4.0 update. (#2371) 2025-06-06 02:39:20 -04:00
2e2af190bd Revert "[ex77] fix mla split; add fwd lse; add bwd varlen (#2366)" (#2370)
This reverts commit f12b1d75c9.
2025-06-05 23:14:57 -04:00
f12b1d75c9 [ex77] fix mla split; add fwd lse; add bwd varlen (#2366) 2025-06-05 18:39:46 -04:00
b244379d9b Merge pull request #2359 from NVIDIA/oss_ci
Initial Workflow Definition for blossom-ci support on CUTLASS GitHub
2025-06-03 14:04:35 -07:00
9d165a3b8e Handle get_masked_trip_count for small length in fmha example (#2292)
* handle get_masked_trip_count for small length

* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

---------

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
2025-05-30 22:51:18 -04:00
b9b110a9ea Correct divmod order in example 77 (blackwell fmha) (#2291)
* correct divmod naming

* order bidh/bidb
2025-05-30 22:50:40 -04:00
8206e7a0f5 Pre-compile in CuteDsl/ampere/elementwise_apply.py (#2340) 2025-05-28 10:24:39 -04:00
6316b6f867 Fix typos (#2311)
Signed-off-by: co63oc <co63oc@users.noreply.github.com>
2025-05-23 08:30:10 -04:00
9354bfd7c1 Keep the documentation consistent with the sgemm_1.cu code. (#2285)
* Keep the documentation consistent with the sgemm_1.cu code.

* fix typo

---------

Co-authored-by: zky <zky@126.com>
2025-05-19 22:53:15 -04:00
5e9b8e2a25 fix docx (#2290)
Co-authored-by: xiayongqiang <xiayq1@chinatelecom.cn>
2025-05-19 22:52:37 -04:00
1ec230c4bf Fix typo (#2299)
Needs == for pip to parse the file
2025-05-15 09:38:42 -04:00
f89cd95b16 Update elementwise_add.ipynb (#2298) 2025-05-15 09:38:27 -04:00
f115c3f854 Release v4.0.0 (#2294) 2025-05-13 15:55:29 -04:00
ad7b2f5e84 3.9.2 doc/version (#2279)
* 3.9.2 doc/version

* whitespace
2025-05-04 00:00:15 -04:00
40f124ef27 [CUTLASS] Add GNA to PUBLICATIONS.md (#2276)
Adds "Generalized Neighborhood Attention" to list of publications using
CUTLASS.

https://arxiv.org/abs/2504.16922

Co-authored-by: Ali Hassani <ahassani@nvidia.com>
2025-05-02 16:57:19 -04:00
89f6bf2739 Fix group scale gemm when K==128 (#2275)
Co-authored-by: Jiazhen Han <jiazhenh@nvidia.com>
2025-05-02 15:41:18 -04:00
f535c33634 3.9.1 doc/version change (#2273) 2025-05-01 00:27:00 -04:00
e3cb8a773a Import cuda, cudart, nvrtc lazily (#2251)
* Lazy cuda import

* More lazy cuda import

* More lazy cuda imports

* minor fixes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2025-04-30 23:10:33 -04:00
c4bdfe821c Lazy scipy import (#2250) 2025-04-30 16:10:00 -04:00
b3ce7e12b7 Make cc a positional argument (#2249) 2025-04-30 16:09:25 -04:00
fe75ead92e Import pydot lazily (#2248) 2025-04-30 16:08:17 -04:00
35136f5564 Fix wrong detection of python version for use_rmm. (#2224) 2025-04-30 15:29:33 -04:00
e5b810bed1 Use cudaMemcpyAsync in gemm grouped with kRequiresPrecomputation schedule. (#2256)
Co-authored-by: Yuhang Qi <qiyuhang@bytedance.com>
2025-04-30 15:28:05 -04:00
2b78c2fe31 cherry-pick feature/hopper-blockwise-generalization-optimization (#2270) 2025-04-29 16:47:22 -04:00
697126019e fix blackwell grouped groupwise hang (#2267) 2025-04-29 11:54:20 -04:00
e94e888df3 Update CHANGELOG.md 2025-04-24 21:51:34 -04:00
be73ad20a5 Update CHANGELOG.md for 3.9 2025-04-24 16:54:06 -04:00
f02a7c2976 Update README.md for 3.9 2025-04-24 16:51:45 -04:00
331a1f5b3f cutlass 3.9 update (#2255)
* cutlass 3.9 update

* rebase

* fixes out of shared memory for blockwise Blackwell

* doc format

* fix issue 2253

* disable host ref by default

* fix sm120 smem capacity

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2025-04-24 15:42:40 -04:00
8e345c5c5b fix_missing_stdint (#2199)
* Update config.hpp

* Update config.hpp

* Update config.hpp
2025-04-23 22:21:22 -04:00
81a43e6d92 Set EpiTile correctly when TileN is not divisible by 32 (#2220)
If TileN is not divisible by 32 (e.g., 208), by default EpiTile would be set
to 128 x 32, which does not compile because TileN is required to divide EpiTileN.
2025-04-21 00:02:51 -04:00
ade6376fa0 [SM90] Change register allocation for TileN=208 to avoid spills (#2219)
With the usual register allocation (producer 40, consumer 232), compiling a GEMM
with tile shape 256 x 208 (cooperative) or 128 x 208 (pingpong) shows lots of
register spilling (e.g. ~3000 bytes spilled). For this case we can change
the register allocation to producer 24, consumer 240, which avoids spills.
2025-04-21 00:02:30 -04:00
bb4dd682dd Fix broken links and alt text in cluster launch control docs (#2234)
* Fix broken links in cluster launch control docs

* Improve titles and alt text
2025-04-21 00:01:12 -04:00
5e497243f7 fix: fig link in cute docs (#2216) 2025-04-10 14:51:41 -04:00
b3f3c7758c Update tile_iterator.cu (#2204)
Some typos in comments
2025-04-10 14:49:58 -04:00
9e1b649827 fix-left-inverse-for-nvcc114 (#2196) 2025-04-10 14:48:46 -04:00
5120b21cc3 suppress compilation warnings (#2195) 2025-04-10 14:48:01 -04:00
dd76dec4ef [Doc] Make C++ code more plausible (#2156)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2025-04-10 14:35:46 -04:00
19cc2a5feb add support for sm89 in cute and the unit tests (#2177)
* add support for sm89 in cute and the unit tests

* rebase v3.9 and format code

* minor fix

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2025-04-10 14:16:36 -04:00
09df6ac464 [Doc]fix typo (#2174)
Co-authored-by: wenju.li <wenju.li@deepctr.cn>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2025-04-10 12:46:53 -04:00
df8a550d39 Update mma_atom.hpp (#2159)
remove useless code
2025-04-03 11:42:10 -04:00
79fc51f4b8 v3.9 update (#2213)
Co-authored-by: yuzhai <yuzhai@nvidia.com>
2025-04-03 02:10:16 -04:00
6f4921858b v3.9 update (#2203)
* v3.9 update

* voidD

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
2025-04-02 15:11:18 -04:00
62750a2b75 v3.9 (#2185)
* v3.8 update x

* fix blackwell gg

* doc change

* doc change

* doc change

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-03-21 01:52:23 -04:00
8c4d1dc47d Treat negative zero as equivalent to positive zero in sm90_sparse_gemm_compressor.hpp (#2110)
* Treat negative zero as zero in the sparse gemm compressor

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

* format

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

* Apply patch

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

* sm90_sparse_gemm_compressor.hpp

* test/unit/transform/CMakeLists.txt

* test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp

* include/cutlass/numeric_types.h

---------

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-03-21 01:44:17 -04:00
3fe62887d8 adding blackwell (#2143) 2025-03-17 22:20:40 -04:00
bd03b22f64 fix typo (#2136)
Co-authored-by: XiaoDong <xiaod@nvidia.com>
2025-03-17 22:19:43 -04:00
6c6b78550e Fix SM90 beta=1 hang and stream-K launch errors (#2172)
* Fix stream-K occupancy calculation

* Fix beta=1 hang
2025-03-13 14:07:37 -04:00
06e560d98a Blockwise/Groupwise kernel improvement and programmatic dependent launch enablement (#2161)
Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
2025-03-10 14:36:11 -04:00
df18f5e4f5 Improvements for: Groupwise scaling along M for FP8 gemm (#2095)
* fix blockwise fp8 kernels

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* wip, < 128 not working

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* fix < 128

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* reduce diff

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* review comments

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* support partial n blocks

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* fix build errors

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-02-27 22:39:29 -05:00
ca4fdbea70 Blockwise and Groupwise GEMM for Blackwell and Improvements for Hopper (#2139)
- Blockwise and Groupwise GEMM improvements for Hopper.
- Blockwise and Groupwise GEMM for Blackwell.
- Blockwise Grouped GEMM for Hopper.
- Static ScalePromotionInterval for Hopper FP8 GEMMs.

Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
2025-02-26 12:44:58 -05:00
800 changed files with 220129 additions and 11081 deletions

.github/CODEOWNERS (new file, 1 addition)

@@ -0,0 +1 @@
# This file defines code ownership rules for the repository.

.github/workflows/blossom-ci.yml (new file, 112 additions)

@@ -0,0 +1,112 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
# A workflow to trigger ci on hybrid infra (github + self hosted runner)
name: Blossom-CI
on:
  issue_comment:
    types: [created]
  workflow_dispatch:
    inputs:
      platform:
        description: 'runs-on argument'
        required: false
      args:
        description: 'argument'
        required: false
jobs:
  Authorization:
    name: Authorization
    runs-on: blossom
    outputs:
      args: ${{ env.args }}
    # This job only runs for pull request comments
    if: |
      (startsWith(github.event.comment.body, '/bot run') ||
      startsWith(github.event.comment.body, '/bot kill')) && contains(
      fromJson('["zekunf-nv"]'),
      github.actor)
    steps:
      - name: Check if comment is issued by authorized person
        run: blossom-ci
        env:
          OPERATION: 'AUTH'
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
  Vulnerability-scan:
    name: Vulnerability scan
    needs: [Authorization]
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
        with:
          repository: ${{ fromJson(needs.Authorization.outputs.args).repo }}
          ref: ${{ fromJson(needs.Authorization.outputs.args).ref }}
          lfs: 'true'
      - name: Run blossom action
        uses: NVIDIA/blossom-action@main
        env:
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
        with:
          args1: ${{ fromJson(needs.Authorization.outputs.args).args1 }}
          args2: ${{ fromJson(needs.Authorization.outputs.args).args2 }}
          args3: ${{ fromJson(needs.Authorization.outputs.args).args3 }}
  Job-trigger:
    name: Start ci job
    needs: [Vulnerability-scan]
    runs-on: blossom
    steps:
      - name: Start ci job
        run: blossom-ci
        env:
          OPERATION: 'START-CI-JOB'
          CI_SERVER: ${{ secrets.CI_SERVER }}
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  Upload-Log:
    name: Upload log
    runs-on: blossom
    if: github.event_name == 'workflow_dispatch'
    steps:
      - name: Jenkins log for pull request ${{ fromJson(github.event.inputs.args).pr }} (click here)
        run: blossom-ci
        env:
          OPERATION: 'POST-PROCESSING'
          CI_SERVER: ${{ secrets.CI_SERVER }}
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}

CHANGELOG.md

@@ -1,34 +1,142 @@
# NVIDIA CUTLASS Changelog
# Changelog
# CUTLASS 4.x
## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)
### CuTe DSL
* CuTe DSL, a Python DSL centered around CuTe's abstractions
- [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass)
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
- [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
- [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
- [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
- Fixed API mismatch in class ``cute.runtime.Pointer``: changed ``element_type`` to ``dtype`` to match ``typing.Pointer``.
### CUTLASS C++
* Support for [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/), which were introduced in CUDA 12.9
- 100f, 101f, and 120f were added to support Family Specific Architecture Features, which allow running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note that 101a has been supported since CUTLASS 3.9.
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
- For example:
+ `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
+ `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
- If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, or to filter kernels in the CUTLASS profiler with `--kernels`), please update your usage accordingly; this is a breaking change.
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
- Added non-power-of-two tile sizes.
- Improved performance for K-major scale factors.
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
* Support LSE output in Blackwell SM100 FMHA Forward kernel in example 77.
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels.
* Fix profiler issues which caused either no output or a "not supported" error for some kernels.
* Optimizations for Blackwell SM100 and SM120 block scaled kernels.
* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler.
* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* CuTe changes:
- Rework `cute::copy_if` so that the predicate is itself a true CuTe Tensor rather than a lambda, and introduce transform-tensors to avoid any extra register or load/store overhead when using bool tensors (see the sketch at the end of this section).
- New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy.
- Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
- Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.9.
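As a rough illustration of the reworked `cute::copy_if` mentioned in the CuTe changes above, here is a minimal sketch. It is not taken from the tutorial file; the `copy_in_bounds` helper, the tensor names, and the use of `elem_less` for an in-bounds mask are illustrative assumptions.

```cpp
#include <cute/tensor.hpp>
using namespace cute;

// Sketch: build a bool predicate tensor from identity coordinates and hand it
// to the reworked copy_if, which now takes the predicate as a real CuTe Tensor.
template <class TensorS, class TensorD, class TensorC, class ProblemShape>
CUTE_HOST_DEVICE
void copy_in_bounds(TensorS const& src,          // thread-partitioned source tile
                    TensorD      & dst,          // thread-partitioned destination tile
                    TensorC const& coord,        // thread-partitioned identity-coordinate tile
                    ProblemShape const& bounds)  // global extent, e.g. (M, N)
{
  // Register-backed predicate tensor with the same (static) shape as the value tiles.
  auto pred = make_tensor<bool>(shape(src));
  for (int i = 0; i < size(pred); ++i) {
    pred(i) = elem_less(coord(i), bounds);  // true iff this element is in bounds
  }
  copy_if(pred, src, dst);                  // copy only the predicated elements
}
```

Compared with the previous lambda-based predicate, a bool tensor lets the same partitioning machinery that shapes `src` and `dst` drive the predication; the tiled-copy version is shown in the `tiled_copy_if.cu` tutorial linked above.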
# CUTLASS 3.x
## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03)
* Fixed [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM hang issue when problem size K is 128.
* Optimal code generation with CUDA toolkit versions 12.9.
## [3.9.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.1) (2025-04-30)
* Fixed Group Gemm hang issue in CUTLASS 3.x
* Improved Hopper [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM performance.
## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-04-24)
* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
- Collective mainloops that target:
* [Blockscaled datatypes with support for dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
* [Blockscaled datatypes with support for sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
- [Blackwell SM120 epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
- [Grouped GEMM with nvfp4 datatype](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
* Set of unit tests that demonstrate the usage of both [sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
* Support for Blackwell SM100 Sparse kernels:
- Collective mainloop that targets:
* [SM100 Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
- [Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
- [Blockscaled Sparse GEMM with NVFP4 input data type](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
- [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
* Set of unit tests that demonstrate the usage of [sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
* A new Multi-head Latent Attention (MLA) kernel for the SM100 Blackwell architecture in the CUTLASS [example](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/) covers the FlashMLA-like weight-absorbed decoding use case.
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
* A new [distributed GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
* Enhancements and new support for block-wise and group-wise GEMM on Hopper and Blackwell architectures (see the scaling sketch at the end of this section):
- Enhancement of [blockwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
- Enhancement of [groupwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
- Support for [grouped GEMM with blockwise and groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
- Support for [groupwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/tools/profiler/src/blockwise_gemm_operation_profiler.cu) in the CUTLASS profiler.
- Support for [blockwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
- Support for [groupwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
- Support for [grouped GEMM with blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler:
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
- A more detailed introduction and examples of how to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
* Support `void` as the D element in sm100 kernel epilogues.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.8U1.
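For readers unfamiliar with the block-wise and group-wise scaling referenced above, the computation these kernels implement can be summarized roughly as follows; this is a hedged sketch of the semantics, and the block-selector symbols are illustrative rather than CUTLASS parameter names.

```latex
D_{ij} \;=\; \sum_{k_b} s_A\bigl(b_M(i),\, k_b\bigr)\; s_B\bigl(b_N(j),\, k_b\bigr) \sum_{k \in k_b} A_{ik}\, B_{kj}
```

Here k_b ranges over blocks of the reduction dimension, and b_M(i), b_N(j) select the scale-factor block containing row i of A and column j of B. Blockwise scaling uses one scale per operand block (e.g. 128x128), while groupwise scaling uses a finer granularity along M, and the scaled partial products are accumulated once per K-block inside the mainloop.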
## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25)
* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
- [5th generation Blackwell Tensor Core instructions (TCGen05)](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
- Extensions to [Tensor Memory Accelerator](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/pointer.hpp) across CuTe as a first class data locale.
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
- [`make_tmem_copy()`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
- Support for [new variants of LDSM on Blackwell](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
- Various narrow precision [FP4, FP6, and FP8](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/float_subbyte.h)
- [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp).
- [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp).
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8, along with all of their supported element and scale factor types (see the sketch at the end of this section).
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
- [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
+ Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
- Blackwell [collective mainloop for convolution kernels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp), [convolution](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/dispatch_policy.hpp), and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
- [Blackwell epilogue that supports loading accumulators from `tmem`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and full set of EVT fusions.
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
- Support for preferred and fallback cluster shapes via profiler command-line argument parsing to set dynamic cluster shapes.
- Support for dynamic datatypes via profiler command-line argument parsing to set the dynamic datatype in TCGen05 MMA instruction descriptors.
@@ -36,131 +144,131 @@
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
* A new 3.x version of grouped GEMM has been added to the CUTLASS library and generates kernels for Hopper and Blackwell. Grouped GEMM support is now enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](https://github.com/NVIDIA/cutlass/tree/main/examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
- GEMM with [opt-in collective builder schedules showcasing available recipes](https://github.com/NVIDIA/cutlass/tree/main/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
- A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
+ [NVFP4 inputs with BF16 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
+ [NVFP4 inputs with NVFP4 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
- [GEMM with CLC based StreamK scheduler for load balancing](https://github.com/NVIDIA/cutlass/tree/main/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
- Grouped GEMM for [vanilla FP8 data inputs](https://github.com/NVIDIA/cutlass/tree/main/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](https://github.com/NVIDIA/cutlass/tree/main/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
- Convolution kernels for [fprop](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
- [Fused multi-head attention fprop kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32, 64, and 128.
- A new BF16x9 GEMM [kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
- A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
- A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
- A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
- A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
- A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
- Updates to [profiler documentation](./media/docs/profiler.md) for testing mixed input GEMM kernels on Hopper.
- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html)
- A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for the 3.x API, comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x-supported architectures.
- Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture).
- Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.
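As a small, hedged illustration of the narrow-precision data types listed at the top of these 3.8.0 notes: round-tripping a few values through the new formats looks roughly like the sketch below. The assumption that `cutlass/numeric_types.h` aggregates all of these types is mine; the exact header for each type may differ.

```cpp
#include <cutlass/numeric_types.h>   // assumed umbrella header for the types below
#include <iostream>

int main() {
  // FP8 (E4M3) and FP4 (E2M1) values constructed from float; both types
  // round to the nearest representable value on construction.
  cutlass::float_e4m3_t fp8(1.75f);
  cutlass::float_e2m1_t fp4(1.75f);

  // Scale factor type used by the MX formats (unsigned E8M0, exponent only).
  cutlass::float_ue8m0_t sf(4.0f);

  // Explicit conversion back to float for inspection.
  std::cout << float(fp8) << " " << float(fp4) << " " << float(sf) << "\n";
  return 0;
}
```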
## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
- [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
- [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
- [Distributed GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
- Potential API breaking changes:
+ Fix `cute::UniversalCopy` for type safety.
+ No longer implicitly select `cute::SM80_CP_ASYNC_*` based on input tensors. This avoids implicit downstream synchronization requirements. To use `SM80_CP_ASYNC`, users must explicitly select the appropriate CopyAtom.
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
+ A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/profiler.md#cutlass-profiler).
+ A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
- Optimal code generation with CUDA toolkit versions 12.6.
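To make the two cp.async-related breaking changes above concrete, here is a minimal sketch; the helper names, tensor shapes, and the choice of `uint128_t`/`half_t` are illustrative assumptions, and partitioning and synchronization details are elided.

```cpp
#include <cute/tensor.hpp>
#include <cute/atom/copy_traits_sm80.hpp>   // SM80 cp.async copy traits
using namespace cute;

// Explicitly choose the cp.async CopyAtom instead of relying on the old
// implicit selection based on the input tensors.
template <class GmemTensor, class SmemTensor>
CUTE_DEVICE void stage_tile(GmemTensor const& gA, SmemTensor& sA)
{
  auto cp_atom = Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, half_t>{};
  copy(cp_atom, gA, sA);   // per-thread partitions assumed 128-bit contiguous
  cp_async_fence();
  cp_async_wait<0>();
}

// Former cute::copy_vec<uint128_t>(src, dst) becomes an auto-vectorizing copy
// with an explicitly assumed alignment in bits.
template <class SrcTensor, class DstTensor>
CUTE_DEVICE void vec_copy(SrcTensor const& src, DstTensor& dst)
{
  copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
}
```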
## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03)
- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
+ [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
+ [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
+ [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
+ [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md).
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper).
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h)
- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu).
- [Hopper structured sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
+ [FP16](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
+ [FP8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
+ [INT8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
+ [TF32](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
- [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speed up two back-to-back kernels, and its corresponding [documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html).
- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
- A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
- A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
- New hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h) (see the sketch at the end of this section).
- Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
- Optimal code generation with CUDA toolkit versions 12.6.
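A tiny, hedged sketch of the `cutlass::bfloat16_t` comparisons and computations mentioned above; the values are arbitrary, and whether a given operation maps to native hardware depends on the target architecture.

```cpp
#include <cutlass/bfloat16.h>
#include <iostream>

int main() {
  cutlass::bfloat16_t a(1.5f);
  cutlass::bfloat16_t b(2.25f);

  bool less    = (a < b);       // comparison operators
  auto sum     = a + b;         // arithmetic on bfloat16_t values
  float as_f32 = float(sum);    // explicit conversion back to float

  std::cout << less << " " << as_f32 << "\n";
  return 0;
}
```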
## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25)
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu)
- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48)
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence:
+ [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
+ [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+ [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+ [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py).
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/wgmma_sm90.cu)
- [Exposure of L2 `cache_hint`s in TMA copy atoms](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/copy_sm90_tma.hpp#L48)
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/cpp/profiler.md#gemm), and
[example 48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
- [TMA store based and EVT supported epilogues](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inference (a split-K partitioning sketch follows this release's notes):
+ [FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
+ [int8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+ [int4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+ [FP32 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
- [CUDA host adapter](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/generator.py).
- Support for residual add (beta != 0) in convolution kernels.
- A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
- A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt).
- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html).
- Better support for MSVC as a host compiler.
- Many performance optimizations, improvements, and bug fixes, including fixes for FlashAttention-2.
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
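As background for the serial and parallel split-K modes referenced in the `GemmSparseUniversal` bullet above, the sketch below shows the partitioning arithmetic in plain C++. It is a conceptual illustration, not the CUTLASS implementation; the variable names are hypothetical.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Conceptual split-K partitioning (not the CUTLASS implementation):
// each of `splits` workers reduces a contiguous slice of the K dimension.
int main() {
  int K = 1000;
  int splits = 4;
  int k_per_split = (K + splits - 1) / splits;  // ceil-divide

  // Parallel split-K: each slice produces a partial result that is
  // reduced afterwards; serial split-K accumulates into one buffer in order.
  std::vector<float> partial(splits, 0.0f);
  for (int s = 0; s < splits; ++s) {
    int k_begin = s * k_per_split;
    int k_end = std::min(K, k_begin + k_per_split);
    for (int k = k_begin; k < k_end; ++k) {
      partial[s] += 1.0f;  // stands in for A(m,k) * B(k,n)
    }
  }
  float acc = 0.0f;
  for (int s = 0; s < splits; ++s) acc += partial[s];  // final reduction
  std::printf("reduced %d k-iterations across %d splits: %.0f\n", K, splits, acc);
  return 0;
}
```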
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp)
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html).
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp).
+ Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
+ [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
+ NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until the 3.7 release. Your feedback is welcome on the design!
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](https://github.com/NVIDIA/cutlass/tree/main/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
+ Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
+ Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
+ [Ampere FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
+ [Turing FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial) (a minimal layout sketch follows this list).
- Extensions to CuTe to support [L2 prefetching](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/copy_sm90_tma.hpp#L1337).
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
- Fixes to greatly reduce build warnings.
- Updates and bugfixes from the community (thanks!)
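To accompany the updated CuTe documentation referenced above, here is a minimal sketch of CuTe's layout-as-function idea. It assumes the `cute/layout.hpp` header and the `make_shape`/`make_stride`/`make_layout` helpers; it is an illustration, not part of the tutorial series itself.

```cpp
#include <cstdio>
#include <cute/layout.hpp>

int main() {
  using namespace cute;

  // A (4,8) layout with strides (8,1): a row-major 4x8 mapping.
  auto layout = make_layout(make_shape(4, 8), make_stride(8, 1));

  // A layout maps a logical coordinate to a linear offset.
  int offset = layout(1, 2);  // expected 1*8 + 2*1 = 10
  std::printf("layout(1,2) = %d, size = %d\n", offset, int(size(layout)));
  return 0;
}
```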
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
- Statically available [CUTLASS Version macros](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side (a usage sketch follows this release's notes).
- Improvements for Hopper [Group-GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
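A small usage sketch for the version macros above, assuming the `CUTLASS_MAJOR`, `CUTLASS_MINOR`, and `CUTLASS_PATCH` macros exposed by `cutlass/version.h` (check the header for the exact set of macros available in your release).

```cpp
#include <cstdio>
#include "cutlass/version.h"

int main() {
  // Compile-time branching on the CUTLASS release being built against.
#if defined(CUTLASS_MAJOR) && ((CUTLASS_MAJOR > 3) || (CUTLASS_MAJOR == 3 && CUTLASS_MINOR >= 4))
  std::printf("Built against CUTLASS %d.%d.%d (3.4+ API available)\n",
              CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH);
#else
  std::printf("Built against a pre-3.4 CUTLASS release\n");
#endif
  return 0;
}
```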
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
* Expanded [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors (a per-group dequantization sketch follows this release's notes).
* Performance improvements to [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm).
* Beta release of [Pointer-Array Batched GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* Beta release of [Group-GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* [Ampere Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
* NamedBarriers usability improvements; the list of [ReservedNamedBarriers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/barrier.h) has been officially released.
* Improved CuTe documentation, including improved clarity and depth of [Quickstart](./media/docs/cpp/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) have also been improved.
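To make the group scaling factors mentioned in the mixed-dtype bullets concrete, the sketch below dequantizes an int8 operand with one scale per K-group. It is plain C++ and illustrative only; `group_size` and the data layout are assumptions, not the CUTLASS kernel's internals.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Conceptual per-group dequantization used by mixed-input GEMMs:
// each contiguous group of `group_size` elements along K shares one scale.
int main() {
  int K = 8;
  int group_size = 4;
  std::vector<int8_t> quantized = {12, -3, 7, 0, 5, -9, 1, 4};
  std::vector<float> scales = {0.05f, 0.02f};  // one scale per K-group

  std::vector<float> dequantized(K);
  for (int k = 0; k < K; ++k) {
    float scale = scales[k / group_size];
    dequantized[k] = scale * static_cast<float>(quantized[k]);
    std::printf("w[%d] = %d * %.2f = %.2f\n", k, quantized[k], scale, dequantized[k]);
  }
  return 0;
}
```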
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
* [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}.
* [Copy Async based Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu), which support input tensors with less than 16B alignment.
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
* Profiler support for lower-aligned Hopper GEMMs.
* Performance Improvements to [Scatter-Gather Hopper Example](https://github.com/NVIDIA/cutlass/tree/main/examples/52_hopper_gather_scatter_fusion).
* Sub-Byte type fixes and improvements.
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
* Support for backprop fusions, including drelu, dgelu, and dbias.
* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
* SM80 EVT support in C++ and Python.
* Other SM90 epilogue improvements.
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) for details.
* SM90 TF32 kernel improvements for all layouts.
* SM90 rasterization direction support in the CUTLASS profiler.
* Improvement for CUTLASS profiler build times.
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
* New [Epilogue Visitor Tree (EVT)](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allow for user-defined, customized epilogue fusion patterns without having to write a new epilogue.
* [Stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* [Hopper GEMM+Permute](https://github.com/NVIDIA/cutlass/tree/main/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
* New CUTLASS 2D Convolution Python interface. New [example](https://github.com/NVIDIA/cutlass/tree/main/examples/python/03_basic_conv2d.ipynb) here.
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python).
* New [efficient epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
* Support for [fused epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues (a per-element sketch follows this release's notes).
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* New [*warp-specialized persistent cooperative*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
* An [example](https://github.com/NVIDIA/cutlass/tree/main/examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
* Epilogue builders. Similar to mainloop builders (see [example 49](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
* Performance optimizations for the [*warp-specialized persistent ping-pong*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
* Changes to the [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* [FMHA Backward Pass](https://github.com/NVIDIA/cutlass/tree/main/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
* [Streamk GEMM with Broadcast](https://github.com/NVIDIA/cutlass/tree/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
* [Batched B2B GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
* [Batched Strided GEMV](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
* [Permute + GEMM fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/39_gemm_permute) can now fuse Permute with the following GEMM. Previously, only fusing Permute into the GEMM epilogue was supported.
* [Row Broadcast](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
* The GitHub branch is renamed from `master` to `main` in this release.
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
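The fused epilogues above combine the usual linear combination with a bias term and an elementwise activation. The sketch below shows the per-element math in plain C++; `alpha`, `beta`, and the ReLU choice are illustrative, and this is not the CUTLASS epilogue code itself.

```cpp
#include <algorithm>
#include <cstdio>

// Per-element fused epilogue: D = activation(alpha * accumulator + beta * C + bias)
float fused_epilogue(float accumulator, float source_c, float bias,
                     float alpha, float beta) {
  float linear = alpha * accumulator + beta * source_c + bias;
  return std::max(linear, 0.0f);  // ReLU as the elementwise activation
}

int main() {
  float d = fused_epilogue(/*accumulator=*/2.5f, /*source_c=*/1.0f,
                           /*bias=*/-4.0f, /*alpha=*/1.0f, /*beta=*/0.5f);
  std::printf("D = %.2f\n", d);  // relu(2.5 + 0.5 - 4.0) = 0.00
  return 0;
}
```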
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
* [CuTe](./media/docs/cpp/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
* [A new conceptual operation hierarchy](./media/docs/cpp/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/cpp/gemm_api_3x.md).
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cpp/cutlass_3x_backwards_compatibility.md).
* Updates to [Functionality](./media/docs/cpp/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#target-architecture).
* New warp-specialized GEMM [kernel schedules](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
* [CUTLASS library integration](https://github.com/NVIDIA/cutlass/tree/main/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
* Support for [Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm), [49](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_schedules_with_collective_builder), and [50](https://github.com/NVIDIA/cutlass/tree/main/examples/50_hopper_gemm_with_epilogue_swizzle).
# CUTLASS 2.x
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
* [Stream-K](https://github.com/NVIDIA/cutlass/tree/main/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
* [Fused multi-head attention Kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
* [Dual GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency.
* Hopper improves [double precision matrix multiplication](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It has been supported since CUDA 11.8.
* [BLAS3](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions.
* [ELL Block Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of the A matrix. The B and output matrices are still dense. The block size can be arbitrary.
* Optimized [Group Conv](https://github.com/NVIDIA/cutlass/tree/main/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
* [Optimized DepthWise Conv](https://github.com/NVIDIA/cutlass/tree/main/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added:
* [kOptimized](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - uses direct convolution to compute instead of implicit GEMM.
* The restrictions are: 1) the input channel, output channel, and group counts should be multiples of (128 / sizeof(input element)); 2) the input filter size should be the same as the template parameter configuration.
* [kFixedStrideDilation](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - puts stride and dilation into template parameters to further improve performance. In this mode, the kernel keeps some inputs persistent in registers to squeeze out more performance, so large filter/stride/dilation values are not recommended.
* The restrictions are: 1) the input channel, output channel, and group counts should be multiples of (128 / sizeof(input element)); 2) the input filter size, stride, and dilation should be the same as the template parameter configuration.
* [Scripts](https://github.com/NVIDIA/cutlass/tree/main/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
* [FP8 data type definition](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/float8.h) and [conversion routines](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/numeric_conversion.h#L1274-2115) (a conversion sketch follows this release's notes).
* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
* CUDA 10.2
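A minimal host-side sketch of the FP8 types and conversion routines above. It assumes `cutlass::float_e4m3_t` from `cutlass/float8.h` and the `cutlass::NumericConverter` template from `cutlass/numeric_conversion.h`; consult the headers for the available rounding modes.

```cpp
#include <cstdio>

#include "cutlass/float8.h"
#include "cutlass/numeric_conversion.h"

int main() {
  // Round-trip a float through the e4m3 FP8 type.
  cutlass::NumericConverter<cutlass::float_e4m3_t, float> to_fp8;
  cutlass::NumericConverter<float, cutlass::float_e4m3_t> to_f32;

  float x = 3.1415926f;
  cutlass::float_e4m3_t y = to_fp8(x);
  std::printf("fp32 %.7f -> e4m3 -> fp32 %.7f\n", x, to_f32(y));
  return 0;
}
```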
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
* [CUTLASS Python](https://github.com/NVIDIA/cutlass/tree/main/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
* Optimizations for CUTLASS's [Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](https://github.com/NVIDIA/cutlass/tree/main/examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
* Optimizations for [GEMM+Softmax](https://github.com/NVIDIA/cutlass/tree/main/examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
* [Grouped GEMM for Multihead Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/41_multi_head_attention). This general grouped-GEMM-based MHA does not require the sequence lengths of all GEMMs to be the same, which makes it most useful for natural language processing.
* [GEMM + Layer norm fusion for Ampere](https://github.com/NVIDIA/cutlass/tree/main/examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts, and both of them can be fused into the GEMMs before and after separately. In addition to using the sum of squares to compute the layernorm variance, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided in case the sum of squares raises numerical issues.
* [GEMM Epilogue Permutation Fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/39_gemm_permute) can apply a user-provided permutation layout mapping in the GEMM epilogue.
* [Grouped convolution targeting implicit GEMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one. The restrictions are: 1) the input and output channel counts should be multiples of the group count; 2) split-K is not supported. The implementation has two modes (see the arithmetic sketch after this release's notes):
* kSingleGroup: output channel per group is multiple of Threadblock tile N.
* kMultipleGroup: Threadblock tile N is multiple of output channel per group.
* [Depthwise separable convolution](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now. The restrictions are: 1) SIMT only; 2) no split-K; 3) the input channel count equals the output channel count equals the group count.
* Standalone [Layernorm](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
* [Back-to-back GEMM/CONV](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be a multiple of the Threadblock Tile K dimension.
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
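The kSingleGroup / kMultipleGroup distinction above reduces to simple divisibility checks between the per-group output channel count and the threadblock tile N. The sketch below is illustrative; `threadblock_n` is an assumed tile size, not a fixed CUTLASS constant.

```cpp
#include <cstdio>

int main() {
  int output_channels = 256;
  int groups = 4;
  int threadblock_n = 64;  // threadblock tile N (illustrative)

  int channels_per_group = output_channels / groups;  // 64 here

  if (channels_per_group % threadblock_n == 0) {
    // kSingleGroup: each threadblock tile stays within one group.
    std::printf("kSingleGroup: %d channels/group is a multiple of tile N=%d\n",
                channels_per_group, threadblock_n);
  } else if (threadblock_n % channels_per_group == 0) {
    // kMultipleGroup: one threadblock tile spans several whole groups.
    std::printf("kMultipleGroup: tile N=%d covers %d groups\n",
                threadblock_n, threadblock_n / channels_per_group);
  } else {
    std::printf("unsupported configuration for the analytic grouped conv\n");
  }
  return 0;
}
```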
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
* [First layer Convolution kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
* [Few channels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
* [Fixed channels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
* [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
* [Python-based instance emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
* [HERK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/rank_k_operation.py)
* [SYRK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/rank_k_operation.py)
* [SYMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/symm_operation.py)
* [TRMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/trmm_operation.py)
* [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/testbed_rank_k_universal.h)
* [CUTLASS Python](https://github.com/NVIDIA/cutlass/tree/main/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
* [Python-based runtime](https://github.com/NVIDIA/cutlass/tree/main/tools/library/scripts/rt.py) interoperable with existing emitters
* [GEMM + Softmax example](https://github.com/NVIDIA/cutlass/tree/main/examples/35_gemm_softmax)
* [Gather and Scatter Fusion with GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/36_gather_scatter_fusion) can gather inputs and scatter outputs based on index vectors in the same GEMM kernel (a reference loop follows this release's notes).
* It can select random rows in a row major matrix.
* It can select random columns in a column major matrix.
* [Back-to-back GEMM/CONV](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
* Supported kernels: GEMM and CONV.
* Supported types: fp16 and int8.
* Supported architectures: Turing and Ampere.
* [Transposed Convolution](https://github.com/NVIDIA/cutlass/tree/main/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
* [Utility functions](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
* [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores in these kernels.
* Epilogue enhancement:
* Eliminate bank conflicts in int8 tensor core kernels.
* Half2 usage if epilogue compute type is fp16.
* More activation functions: Silu, Hardswish, Leaky Relu.
* New elementwise fusion pattern for [residual block](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
* [Group GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
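A plain-C++ reference loop for the gather/scatter GEMM fusion above: rows of A are selected through a gather index vector and output rows are written through a scatter index vector. This is an illustration of the computation, not the CUTLASS kernel.

```cpp
#include <cstdio>
#include <vector>

// Reference: D[scatter_idx[i]][n] = sum_k A[gather_idx[i]][k] * B[k][n]
int main() {
  int M = 4, N = 2, K = 3;     // logical GEMM extents after gathering
  int rows_A = 8, rows_D = 8;  // physical row counts of A and D

  std::vector<float> A(rows_A * K, 1.0f);
  std::vector<float> B(K * N, 2.0f);
  std::vector<float> D(rows_D * N, 0.0f);
  std::vector<int> gather_idx = {6, 1, 4, 2};   // rows of A to read
  std::vector<int> scatter_idx = {0, 3, 5, 7};  // rows of D to write

  for (int i = 0; i < M; ++i) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[gather_idx[i] * K + k] * B[k * N + n];
      }
      D[scatter_idx[i] * N + n] = acc;
    }
  }
  std::printf("D[0][0] = %.1f (expected %.1f)\n", D[0], float(K) * 2.0f);
  return 0;
}
```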
* **TF32x3:** emulated single-precision using Tensor Cores (a splitting sketch follows this release's notes)
* 45+ TFLOPs on NVIDIA A100
* [GEMM SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
* [COMPLEX GEMM SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
* [Implicit GEMM Convolution SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
* [Conv Fprop SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
* [Conv WGrad SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
* [SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
* [cutlass::gemm::device::GemmGrouped](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_grouped.h)
* [Implicit GEMM Convolution fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/) supports staging the 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less register pressure.
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
* Updates from the community (thanks!)
* CUDA 10.2
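The TF32x3 emulation above splits each FP32 operand into a large TF32 part and a small TF32 residual so that three Tensor Core products recover near-FP32 accuracy. The sketch below is conceptual; the truncation helper mimics TF32's 10-bit mantissa with round-toward-zero and is not the CUTLASS conversion routine.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate an fp32 value to a TF32-like value (10 explicit mantissa bits)
// by zeroing the low 13 mantissa bits. Illustrative round-toward-zero only.
float to_tf32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  float a = 1.2345678f, b = 2.7182818f;

  // Split each operand into a large TF32 part and a small TF32 residual.
  float a_big = to_tf32(a), a_small = to_tf32(a - a_big);
  float b_big = to_tf32(b), b_small = to_tf32(b - b_big);

  // Three products approximate the full fp32 product.
  float approx = a_big * b_big + a_big * b_small + a_small * b_big;
  std::printf("fp32: %.9f  3xTF32: %.9f  1xTF32: %.9f\n",
              a * b, approx, a_big * b_big);
  return 0;
}
```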
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
* Mainloop fusion for GEMM: [summation over A or B](https://github.com/NVIDIA/cutlass/tree/main/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
* [Strided DGRAD (optimized iterators)](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
* [Half-precision GELU_taylor activation functions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/activation.h#L196)
* Use these when accumulation and epilogue compute types are all `cutlass::half_t` (an illustrative approximation appears after this release's notes).
* Tuning and bug fixes to [fused GEMM + GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/)
* Support for smaller than 128b aligned Convolutions: [see examples](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
* Caching of results to accelerate Convolution [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/cache_testbed_output.h)
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
* Corrections and bug fixes reported by the CUTLASS community
* Thank you for filing these issues!
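For the half-precision GELU_taylor activation above, the commonly used tanh-based approximation is sketched below in FP32 for clarity. The constants and the FP32 evaluation are illustrative; the CUTLASS functor itself operates on `cutlass::half_t` and may differ in detail.

```cpp
#include <cmath>
#include <cstdio>

#include "cutlass/half.h"

// Tanh-based GELU approximation, evaluated in fp32 here for clarity;
// the half-precision functor keeps the data in cutlass::half_t.
float gelu_approx(float x) {
  const float k0 = 0.7978845608028654f;  // sqrt(2/pi)
  const float k1 = 0.044715f;
  return 0.5f * x * (1.0f + std::tanh(k0 * (x + k1 * x * x * x)));
}

int main() {
  cutlass::half_t x(0.5f);
  cutlass::half_t y(gelu_approx(float(x)));
  std::printf("gelu(%.3f) ~= %.4f\n", float(x), float(y));
  return 0;
}
```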
@ -343,24 +453,24 @@
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
* Adopt the new L2 prefetch feature in [cp.async](./include/cutlass/arch/memory.h) and [global load](./include/cutlass/arch/memory_sm80.h)
* Adopt the new L2 prefetch feature in [cp.async](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/memory.h) and [global load](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/memory_sm80.h)
* Fused operators with GEMM and Convolution
* [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
* [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
* [Fused broadcast in epilogue](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
* [Fused partial reduction in epilogue](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
* 64b tensor strides and leading dimensions support for GEMMs
* Affine rank=2 matrix layouts (an index-mapping sketch follows this list)
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h)
* Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
* [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation
* [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/layout/matrix.h)
* Support [FP64 tensor core](https://github.com/NVIDIA/cutlass/tree/main/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
* [Batched GEMV](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemv.cu) preview implementation
* [New strided Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
* Accelerates the previous implementation by cutting redundant math by 4x
* Supports the new `Dy` and `w` analytic iterators and the existing `cutlass::conv::device::ImplicitGemmConvolution` interface
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
* Updates to [quaternion.h](./include/cutlass/quaternion.h) and [functional.h](./include/cutlass/functional.h)
* SDK Example for [GEMM](./examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](./examples/22_quaternion_conv/quaternion_conv.cu)
* [Unit tests for GEMM](./test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](./test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
* Updates to [quaternion.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/quaternion.h) and [functional.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/functional.h)
* SDK Example for [GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](https://github.com/NVIDIA/cutlass/tree/main/examples/22_quaternion_conv/quaternion_conv.cu)
* [Unit tests for GEMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
* Many improvements to the epilogue.
* Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
* Provide an [option](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
* Performance improvement for FP16 tensor core kernels
* Bug fixes
* Enhanced Clang support: the combination of Clang 13 and CUDA 11.4 can build and run kernels for Pascal and Ampere.
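As a conceptual aid for the affine rank-2 layout entry above: the linear offset of an element is an affine function of its (row, column) coordinate with independent row and column strides, which covers column-major (row stride 1), row-major (column stride 1), and more general interleaved storage. The struct below is a minimal sketch of that index mapping, not the CUTLASS layout class itself.

```c++
// Minimal sketch of an affine rank-2 index mapping: the offset is a linear
// combination of the (row, column) coordinate with independent strides.
// Column-major corresponds to stride_row = 1, row-major to stride_col = 1.
struct AffineRank2Sketch {
  long long stride_row;  // elements skipped when the row index increases by 1
  long long stride_col;  // elements skipped when the column index increases by 1

  long long offset(int row, int col) const {
    return static_cast<long long>(row) * stride_row +
           static_cast<long long>(col) * stride_col;
  }
};
```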
@ -372,14 +482,14 @@
## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
* Tensor reductions
* _m_-to-_n_ reductions of tensors with affine layout
* [Specializations](./test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
* [Specializations](./test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
* [Specializations](https://github.com/NVIDIA/cutlass/tree/main/test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
* [Specializations](https://github.com/NVIDIA/cutlass/tree/main/test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
* Custom reduction functors such as `cutlass::logical_and` (a minimal functor sketch follows this list)
* Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31)
* Optimizations for 3-D convolution
* [Optimized tile iterators](./include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
* Full coverage of [forward](test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
* [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md)
* [Optimized tile iterators](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
* Full coverage of [forward](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
* [Fused Convolution+Convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/README.md)
* Corrections and bug fixes reported by the CUTLASS community
* Thank you for filing these issues!
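As a conceptual aid for the custom reduction functor entry above: a reduction functor is just a stateless binary operator that the device-level tensor reduction applies repeatedly to combine elements. The struct below is illustrative only, in the spirit of `cutlass::logical_and`, and is not the library type.

```c++
// Illustrative binary reduction functor (compile with nvcc for the CUDA
// qualifiers): the tensor reduction repeatedly applies operator() to fold
// elements along the reduced dimensions into a single accumulator.
struct LogicalAndSketch {
  __host__ __device__
  bool operator()(bool accumulator, bool element) const {
    return accumulator && element;
  }
};
```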
@ -394,21 +504,21 @@
* Global memory iterators supporting Fprop, Dgrad, and Wgrad
* `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
* `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
* [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
* [Documentation](./media/docs/cpp/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
* [Sparse Tensor Core GEMM kernels](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
* [Sparse Tensor Core GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
* Minor Features:
* [Activation functions](./include/cutlass/epilogue/thread/activation.h) such as [GeLU](./include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](./include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
* Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code
* [Floating-point constants](./include/cutlass/constants.h)
* [Activation functions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/activation.h) such as [GeLU](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
* Small [matrix](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/matrix.h) and [quaternion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/quaternion.h) template classes in device code
* [Floating-point constants](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/constants.h)
* NVIDIA Ampere GPU Architecture examples and documentation:
* [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
* [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue)
* [Tensor Float 32](https://github.com/NVIDIA/cutlass/tree/main/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
* [Sparse Tensor Cores](https://github.com/NVIDIA/cutlass/tree/main/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/cpp/gemm_api.md#efficient-epilogue)
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
@ -428,11 +538,11 @@
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library)
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/cpp/quickstart.md#cutlass-library)
* API to launch compiled kernel instances for GEMM and planar complex GEMM
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts (see the sketch after this list)
* [SDK Examples of Planar Complex GEMMs](./examples/10_planar_complex/planar_complex.cu)
* [SDK Examples of Planar Complex GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/10_planar_complex/planar_complex.cu)
* Minor enhancements and bug fixes
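As a conceptual aid for the planar complex entry above: each complex matrix is stored as two disjoint real-valued planes (real and imaginary), so every complex multiply-accumulate expands into four real ones. The element-level sketch below illustrates that expansion; it is not the CUTLASS kernel interface.

```c++
// Element-level view of planar complex multiply-accumulate with disjoint
// real and imaginary planes:
//   C_real += A_real * B_real - A_imag * B_imag
//   C_imag += A_real * B_imag + A_imag * B_real
struct PlanarComplexMac {
  float c_real = 0.0f;
  float c_imag = 0.0f;

  void accumulate(float a_real, float a_imag, float b_real, float b_imag) {
    c_real += a_real * b_real - a_imag * b_imag;
    c_imag += a_real * b_imag + a_imag * b_real;
  }
};
```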
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
@ -442,10 +552,10 @@
* Encapsulated functionality embodying modern C++11 programming techniques
* Optimized containers and data types for efficient, generic, portable device code
* Updates to:
* [Quick start guide](./media/docs/quickstart.md)
* [Quick start guide](./media/docs/cpp/quickstart.md)
* [Documentation](./README.md#documentation)
* [Utilities](./media/docs/utilities.md)
* [CUTLASS Profiler](./media/docs/profiler.md)
* [Utilities](./media/docs/cpp/utilities.md)
* [CUTLASS Profiler](./media/docs/cpp/profiler.md)
* Native Turing Tensor Cores
* Efficient GEMM kernels targeting Turing Tensor Cores
* Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands
@ -538,4 +648,3 @@ SPDX-License-Identifier: BSD-3-Clause
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -102,6 +102,8 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -ftemplate-backtrace-limit=0)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
endif()
@ -173,7 +175,13 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
@ -338,6 +346,10 @@ if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
endif()
if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()
set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")
#
@ -382,7 +394,21 @@ endif()
if (CUTLASS_ENABLE_GDC_FOR_SM90)
message(STATUS "Grid Dependency Control (GDC) is enabled for SM90 kernels (required for programmatic dependent launches).")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1)
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1)
endif()
if (NOT DEFINED CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT)
set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT ON)
endif()
set(CUTLASS_ENABLE_GDC_FOR_SM100
${CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT}
CACHE BOOL
"Enables Grid Dependency Control (GDC) for SM100 kernels (required for PDL).")
if (CUTLASS_ENABLE_GDC_FOR_SM100)
message(STATUS "Grid Dependency Control (GDC) is enabled for SM100 kernels (required for programmatic dependent launches).")
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM100=1)
endif()
set(CUTLASS_ENABLE_SYNCLOG OFF CACHE BOOL "Enable synchronization event logging for race condition debugging. WARNING: This redefines __syncthreads() and __syncwarp() in all downstream code!")
@ -427,7 +453,7 @@ if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v) # --keep-dir may not work with nvcc for some directories.
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v -objtemp) # --keep-dir may not work with nvcc for some directories.
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
endif()
@ -454,6 +480,13 @@ if(UNIX)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
endif()
# Known ctk11.4 issue (fixed later)
# Also see https://stackoverflow.com/questions/64523302/cuda-missing-return-statement-at-end-of-non-void-function-in-constexpr-if-fun
if (CUDA_VERSION VERSION_LESS 11.5.0)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcudafe "--diag_suppress=implicit_return_from_non_void_function" )
message("CUDA_VERSION check pass ${CUDA_VERSION}")
endif()
# Don't leak lineinfo in release builds
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt)
@ -653,25 +686,6 @@ if (NOT CUTLASS_NAMESPACE STREQUAL "cutlass")
target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE})
endif()
if (NOT DEFINED CUTLASS_REVISION)
find_package(Git QUIET)
execute_process(
COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD
RESULT_VARIABLE CUTLASS_REVISION_RESULT
OUTPUT_VARIABLE CUTLASS_REVISION
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if (CUTLASS_REVISION_RESULT)
message(STATUS "CUTLASS Revision: Unable to detect, Git returned code ${CUTLASS_REVISION_RESULT}.")
else()
message(STATUS "CUTLASS Revision: ${CUTLASS_REVISION}")
endif()
endif()
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h
@ -690,6 +704,7 @@ target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)
install(
@ -1031,6 +1046,7 @@ function(cutlass_generate_profiler_tests NAME)
string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}")
string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}")
string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}")
string(REPLACE "swizzle_size=" "" TEST_NAME "${TEST_NAME}")
string(REGEX REPLACE "verification_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}")
string(REGEX REPLACE "warmup_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")
string(REGEX REPLACE "profiling_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")


@ -2,7 +2,7 @@
[README](./README.md#documentation) > **Contributors**
# CUTLASS Developers **
# CUTLASS C++ Developers **
Andrew Kerr<br />
Paul Springer<br />
@ -70,8 +70,49 @@ Shreya Gaur<br />
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
# CUTLASS DSL Developers ***
# CUTE Developers
Albert Di<br />
Albert Xu<br />
Anakin Zheng<br />
Arvin Jou<br />
Brandon Sun<br />
Chenyang Xu<br />
Chunyu Wang<br />
Cris Cecka<br />
dePaul Miller<br />
Edward Cao<br />
Fung Xie<br />
Guray Ozen<br />
Hao Hu<br />
Hong Wang<br />
Jeremy Furtek<br />
Jie Fang<br />
JingZe Cui<br />
Kihiro Bando<br />
Linfeng Zheng<br />
Longsheng Du<br />
Mina Sun<br />
Mindy Li<br />
Pradeep Ramani<br />
Questa Wang<br />
Serif Yesil<br />
Tao Xie<br />
Tina Li<br />
Vicki Wang<br />
Vincent Zhang<br />
Vijay Thakkar<br />
Xiao Dong<br />
Xiaolei Shi<br />
Xinyu Wang<br />
Yihan Chen<br />
Yuhan Li<br />
Zekun Fan<br />
*** _Sorted in alphabetical order._
# CuTe Developers
Cris Cecka<br />
Vijay Thakkar<br />
@ -100,6 +141,9 @@ David Tanner<br />
Tri Dao<br />
Jay Shah<br />
Mehdi Amini<br />
Larry Wu<br />
Justin Holewinski<br />
Timothy Costa<br />
Julien Demouth<br />
Brian Fahs<br />
@ -108,14 +152,11 @@ Michael Goldfarb<br />
Mostafa Hagog<br />
Fei Hu<br />
Alan Kaatz<br />
Tina Li<br />
Wei Liu<br />
Tim Martin<br />
Kevin Siu<br />
Markus Tavenrath<br />
John Tran<br />
Vicki Wang<br />
Fung Xie<br />
Yang Xu<br />
Scott Yokim<br />
Girish Bharambe<br />
@ -128,3 +169,35 @@ Bryce Lelbach<br />
Joel McCormack<br />
Kyrylo Perelygin<br />
Sean Treichler<br />
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -57,7 +57,7 @@ if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])")
elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang")
set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation")
else()
message(FATAL_ERROR "Uknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.")
message(FATAL_ERROR "Unknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.")
endif()
if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30")

EULA.txt (new file, 188 lines)

@ -0,0 +1,188 @@
NVIDIA Software License Agreement
IMPORTANT NOTICE: PLEASE READ AND AGREE BEFORE USING THE SOFTWARE
This software license agreement (“Agreement”) is a legal agreement between you, whether an individual or entity, (“you”) and NVIDIA Corporation (“NVIDIA”) and governs the use of the NVIDIA CUTLASS DSLs software and materials that NVIDIA delivers to you under this Agreement (“Software”).
NVIDIA and you are each a “party” and collectively the “parties.”
This Agreement can be accepted only by an adult of legal age of majority in the country in which the Software is used.
If you don't have the required age or authority to accept this Agreement, or if you don't accept all the terms and conditions of this Agreement, do not use the Software.
1. License Grants
1.1. License Grant to You. The Software made available by NVIDIA to you is licensed, not sold.
Subject to the terms of this Agreement, NVIDIA grants you a limited, non-exclusive, revocable, non-transferable, and non-sublicensable (except as expressly granted in this Agreement), license to:
a. install and use copies of the Software,
b. configure the Software using configuration files provided (if applicable),
c. modify and create derivative works of any sample or example source code NVIDIA delivers to you as part of the Software (“Derivatives”) (if applicable), and
d. distribute python files in the Software package in source format as incorporated into a software application subject to the following distribution requirements:
i. Your application must have material additional functionality, beyond the included portions of the Software.
ii. The distributable portions of the Software shall only be accessed by your application.
iii. The following notice shall be included in modifications and derivative works of sample source code distributed: “This software contains source code provided by NVIDIA Corporation.”
iv. Unless a developer tool is identified in this Agreement as distributable, it is delivered for your internal use only.
v. The terms under which you distribute your application must be consistent with the terms of this Agreement, including (without limitation) terms relating to the license grant and license restrictions and protection of NVIDIA's intellectual property rights.
vi. Additionally, you agree that you will protect the privacy, security and legal rights of your application users.
The foregoing (a) through (d) are, collectively, the “Purpose”, and the developed applications are only for use in systems with NVIDIA GPUs.
1.2. License Grant to NVIDIA. Subject to the terms of this Agreement, you grant NVIDIA and its affiliates a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit at NVIDIA's discretion any Derivatives created by or for you.
You may, but are not required to, deliver any Derivatives to NVIDIA.
2. License Restrictions
Your license to use the Software and Derivatives is restricted as stated in this Section 2 (“License Restrictions”).
You will cooperate with NVIDIA and, upon NVIDIA's written request, you will confirm in writing and provide reasonably requested information to verify your compliance with the terms of this Agreement.
You may not:
2.1. Use the Software or Derivatives for any purpose other than the Purpose;
2.2. Sell, rent, sublicense, transfer, distribute or otherwise make available to others (except authorized users as stated in Section 3 (“Authorized Users”)) any portion of the Software or Derivatives, except as expressly granted in Section 1.1 (“License Grant to You”);
2.3. Reverse engineer, decompile, or disassemble the Software components provided in binary form, nor attempt in any other manner to obtain source code of such Software;
2.4. Modify or create derivative works of the Software, except as expressly granted in Section 1.1 (“License Grant to You”);
2.5. Change or remove copyright or other proprietary notices in the Software;
2.6. Bypass, disable, or circumvent any technical limitation, encryption, security, digital rights management or authentication mechanism in the Software;
2.7. Use the Software or Derivatives in any manner that would cause them to become subject to an open source software license, subject to the terms in Section 6 (“Components Under Other Licenses”);
2.8. Use the Software or Derivatives in violation of any applicable law or regulation in relevant jurisdictions
2.9. Indicate that a product or service developed with the Software or Derivatives is sponsored or endorsed by NVIDIA;
2.10. Replace any NVIDIA software components in the Software that are governed by this Agreement with other software that implements NVIDIA APIs;
2.11. Reverse engineer, decompile or disassemble any portion of the output generated using Software elements for the purpose of translating such output artifacts to target a non-NVIDIA platform; or
3. Authorized Users
You may allow employees and contractors of your entity or of your subsidiary(ies), and for educational institutions also enrolled students, to internally access and use the Software as authorized by this Agreement from your secure network to perform the work authorized by this Agreement on your behalf.
You are responsible for the compliance with the terms of this Agreement by your authorized users.
Any act or omission that if committed by you would constitute a breach of this Agreement will be deemed to constitute a breach of this Agreement if committed by your authorized users.
4. Pre-Release
Software versions identified as alpha, beta, preview, early access or otherwise as pre-release (“Pre-Release”) may not be fully functional, may contain errors or design flaws, and may have reduced or different security, privacy, availability and reliability standards relative to NVIDIA commercial offerings.
You use Pre-Release Software at your own risk. NVIDIA did not design or test the Software for use in production or business-critical systems.
NVIDIA may choose not to make available a commercial version of Pre-Release Software.
NVIDIA may also choose to abandon development and terminate the availability of Pre-Release Software at any time without liability.
5. Updates
NVIDIA may at any time and at its option, change, discontinue, or deprecate any part, or all, of the Software, or change or remove features or functionality, or make available patches, workarounds or other updates to the Software.
Unless the updates are provided with their separate governing terms, they are deemed part of the Software licensed to you under this Agreement, and your continued use of the Software is deemed acceptance of such changes.
6. Components Under Other Licenses
The Software may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms (“Other Licenses”).
The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights;
except that this Agreement will prevail regarding the use of third-party open source software, unless a third-party open source software license requires its license terms to prevail.
Open source software license means any software, data or documentation subject to any license identified as an open source license by the Open Source Initiative (http://opensource.org), Free Software Foundation (http://www.fsf.org) or other similar open source organization or listed by the Software Package Data Exchange (SPDX) Workgroup under the Linux Foundation (http://www.spdx.org).
7. Ownership
7.1. NVIDIA Ownership. The Software, including all intellectual property rights, is and will remain the sole and exclusive property of NVIDIA or its licensors.
Except as expressly granted in this Agreement, (a) NVIDIA reserves all rights, interests and remedies in connection with the Software, and (b) no other license or right is granted to you by implication, estoppel or otherwise.
7.2. Your Ownership. Subject to the rights of NVIDIA and its suppliers in the Software, which continue to be licensed as stated in this Agreement, even when incorporated in your products or services, and the extent permitted by applicable law, as between you and NVIDIA, you hold all rights, title and interest in and to your products, services and Derivatives you develop as permitted in this Agreement including their respective intellectual property rights.
8. Feedback
You may, but you are not obligated to, provide suggestions, requests, fixes, modifications, enhancements, or other feedback regarding the Software (collectively, “Feedback”).
Feedback, even if designated as confidential by you, will not create any confidentiality obligation for NVIDIA or its affiliates.
If you provide Feedback, you grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit the Feedback at NVIDIA's discretion.
9. Termination
9.1. Termination. This Agreement will automatically terminate without notice from NVIDIA if you fail to comply with any of the terms in this Agreement or if you commence or participate in any legal proceeding against NVIDIA with respect to the Software.
Additionally, either party may terminate this Agreement at any time with thirty (30) days' advance written notice to the other party.
9.2. Effect of Termination. Upon any expiration or termination of this Agreement, you will promptly (a) stop using and return, delete or destroy NVIDIA confidential information and all Software received under this Agreement, and (b) delete or destroy Derivatives created under this Agreement, unless an authorized NVIDIA representative provides prior written approval that you may keep a copy of the Derivatives solely for archival purposes.
Upon written request, you will certify in writing that you have complied with your obligations under this Section 9.2 (“Effect of Termination”).
9.3. Survival. Section 1.2 (“License Grant to NVIDIA”), Section 5 (“Updates”), Section 6 (“Components Under Other Licenses”), Section 7 (“Ownership”), Section 8 (“Feedback”), Section 9.2 (“Effect of Termination”), Section 9.3 (“Survival”), Section 10 (“Disclaimer of Warranties”), Section 11 (“Limitation of Liability”), Section 12 (“Use in Mission Critical Applications”), Section 13 (“Governing Law and Jurisdiction”), Section 14 (“Indemnity”) and Section 15 (“General”) will survive any expiration or termination of this Agreement.
10. Disclaimer of Warranties
THE SOFTWARE IS PROVIDED BY NVIDIA AS-IS AND WITH ALL FAULTS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA DISCLAIMS ALL WARRANTIES AND REPRESENTATIONS OF ANY KIND, WHETHER
EXPRESS, IMPLIED OR STATUTORY, RELATING TO OR ARISING UNDER THIS AGREEMENT, INCLUDING, WITHOUT LIMITATION, THE WARRANTIES OF TITLE, NONINFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, USAGE OF TRADE AND COURSE OF DEALING. NVIDIA DOES NOT WARRANT OR ASSUME RESPONSIBILITY FOR THE ACCURACY OR COMPLETENESS OF ANY THIRD-PARTY INFORMATION, TEXT, GRAPHICS, LINKS CONTAINED IN THE SOFTWARE.
WITHOUT LIMITING THE FOREGOING, NVIDIA DOES NOT WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS, ANY DEFECTS OR ERRORS WILL BE CORRECTED, ANY CERTAIN CONTENT WILL BE AVAILABLE; OR THAT THE SOFTWARE IS FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS. NO INFORMATION OR ADVICE GIVEN BY NVIDIA WILL IN ANY WAY INCREASE THE SCOPE OF ANY WARRANTY EXPRESSLY PROVIDED IN THIS AGREEMENT.
NVIDIA does not warrant or assume responsibility for the accuracy or completeness of any third-party information, text, graphics or links contained in the Software.
11. Limitations of Liability
11.1. EXCLUSIONS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL NVIDIA BE LIABLE FOR ANY (i) INDIRECT, PUNITIVE, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES, OR (ii) DAMAGES FOR (a) THE COST OF PROCURING SUBSTITUTE GOODS, OR (b) LOSS OF PROFITS, REVENUES, USE, DATA OR GOODWILL ARISING OUT OF OR RELATED TO THIS AGREEMENT, WHETHER BASED ON BREACH OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY, OR OTHERWISE, AND EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES AND EVEN IF A PARTY'S REMEDIES FAIL THEIR ESSENTIAL PURPOSE.
11.2. DAMAGES CAP. ADDITIONALLY, TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA'S TOTAL CUMULATIVE AGGREGATE LIABILITY FOR ANY AND ALL LIABILITIES, OBLIGATIONS OR CLAIMS ARISING OUT OF OR RELATED TO THIS AGREEMENT WILL NOT EXCEED FIVE U.S. DOLLARS (US$5).
12. Use in Mission Critical Applications
You acknowledge that the Software provided under this Agreement is not designed or tested by NVIDIA for use in any system or application where the use or failure of such system or application developed with NVIDIA's Software could result in injury, death or catastrophic damage (each, a “Mission Critical Application”).
Examples of Mission Critical Applications include use in avionics, navigation, autonomous vehicle applications, AI solutions for automotive products, military, medical, life support or other mission-critical or life-critical applications.
NVIDIA will not be liable to you or any third party, in whole or in part, for any claims or damages arising from these uses.
You are solely responsible for ensuring that systems and applications developed with the Software include sufficient safety and redundancy features and comply with all applicable legal and regulatory standards and requirements.
13. Governing Law and Jurisdiction
This Agreement will be governed in all respects by the laws of the United States and the laws of the State of Delaware, without regard to conflict of laws principles or the United Nations Convention on Contracts for the International Sale of Goods.
The state and federal courts residing in Santa Clara County, California will have exclusive jurisdiction over any dispute or claim arising out of or related to this Agreement, and the parties irrevocably consent to personal jurisdiction and venue in those courts;
except that either party may apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction.
14. Indemnity
By using the Software you agree to defend, indemnify and hold harmless NVIDIA and its affiliates and their respective officers, directors, employees and agents from and against any claims, disputes, demands, liabilities, damages, losses, costs and expenses arising out of or in any way connected with (i) products or services that have been developed or deployed with or use the Software, or claims that they violate laws, or infringe, violate, or misappropriate any third party right;
or (ii) use of the Software in breach of the terms of this Agreement.
15. General
15.1. Independent Contractors.
The parties are independent contractors, and this Agreement does not create a joint venture, partnership, agency, or other form of business association between the parties.
Neither party will have the power to bind the other party or incur any obligation on its behalf without the other party's prior written consent.
Nothing in this Agreement prevents either party from participating in similar arrangements with third parties.
15.2. No Assignment.
NVIDIA may assign, delegate or transfer its rights or obligations under this Agreement by any means or operation of law.
You may not, without NVIDIA's prior written consent, assign, delegate or transfer any of your rights or obligations under this Agreement by any means or operation of law, and any attempt to do so is null and void.
15.3. No Waiver.
No failure or delay by a party to enforce any term or obligation of this Agreement will operate as a waiver by that party, or prevent the enforcement of such term or obligation later.
15.4. Trade Compliance.
You agree to comply with all applicable export, import, trade and economic sanctions laws and regulations, as amended, including without limitation U.S. Export Administration Regulations and Office of Foreign Assets Control regulations.
You confirm (a) your understanding that export or reexport of certain NVIDIA products or technologies may require a license or other approval from appropriate authorities and (b) that you will not export or reexport any products or technology, directly or indirectly, without first obtaining any required license or other approval from appropriate authorities, (i) to any countries that are subject to any U.S. or local export restrictions (currently including, but not necessarily limited to, Belarus, Cuba, Iran, North Korea, Russia, Syria, the Region of Crimea, Donetsk People's Republic Region and Luhansk People's Republic Region);
(ii) to any end-user who you know or have reason to know will utilize them in the design, development or production of nuclear, chemical or biological weapons, missiles, rocket systems, unmanned air vehicles capable of a maximum range of at least 300 kilometers, regardless of payload, or intended for military end-use, or any weapons of mass destruction;
(iii) to any end-user who has been prohibited from participating in the U.S. or local export transactions by any governing authority;
or (iv) to any known military or military-intelligence end-user or for any known military or military-intelligence end-use in accordance with U.S. trade compliance laws and regulations.
15.5. Government Rights.
The Software, documentation and technology (“Protected Items”) are “Commercial products” as this term is defined at 48 C.F.R.
2.101, consisting of “commercial computer software” and “commercial computer software documentation” as such terms are used in, respectively, 48 C.F.R.
12.212 and 48 C.F.R. 227.7202 & 252.227-7014(a)(1). Before any Protected Items are supplied to the U.S. Government, you will (i) inform the U.S. Government in writing that the Protected Items are and must be treated as commercial computer software and commercial computer software documentation developed at private expense;
(ii) inform the U.S. Government that the Protected Items are provided subject to the terms of the Agreement;
and (iii) mark the Protected Items as commercial computer software and commercial computer software documentation developed at private expense.
In no event will you permit the U.S. Government to acquire rights in Protected Items beyond those specified in 48 C.F.R.
52.227-19(b)(1)-(2) or 252.227-7013(c) except as expressly approved by NVIDIA in writing.
15.6. Notices.
Please direct your legal notices or other correspondence to legalnotices@nvidia.com with a copy mailed to NVIDIA Corporation, 2788 San Tomas Expressway, Santa Clara, California 95051, United States of America, Attention: Legal Department.
If NVIDIA needs to contact you, you consent to receive the notices by email and agree that such notices will satisfy any legal communication requirements.
15.7. Severability.
If a court of competent jurisdiction rules that a provision of this Agreement is unenforceable, that provision will be deemed modified to the extent necessary to make it enforceable and the remainder of this Agreement will continue in full force and effect.
15.8. Amendment.
Any amendment to this Agreement must be in writing and signed by authorized representatives of both parties.
15.9. Construction.
The headings in the Agreement are included solely for convenience and are not intended to affect the meaning or interpretation of the Agreement.
As required by the context of the Agreement, the singular of a term includes the plural and vice versa.
15.10. Force Majeure.
Neither party will be liable during any period where an event or circumstance prevents or delays that party from performing its obligations under this Agreement and that event or circumstance: (i) is not within the reasonable control of that party and is not the result of that party's negligence, and (ii) cannot be overcome or avoided by that party using reasonably diligent efforts.
15.11. Entire Agreement.
Regarding the subject matter of this Agreement, the parties agree that (a) this Agreement constitutes the entire and exclusive agreement between the parties and supersedes all prior and contemporaneous communications and (b) any additional or different terms or conditions, whether contained in purchase orders, order acknowledgments, invoices or otherwise, will not be binding and are null and void.
(v. May 8, 2025)


@ -25,3 +25,10 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Certain files within this repository are subject to separate licensing terms:
- The files located in the `python/CuTeDSL` directory are licensed under the
NVIDIA End User License Agreement (EULA). Please refer to
https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
for the full terms.


@ -2,10 +2,16 @@
## 2025
- ["Comet: Fine-grained Computation-communication Overlapping for Mixture-of-Experts"](https://arxiv.org/abs/2502.19811). Shulai Zhang, Ningxin Zheng, Haibin Lin, Ziheng Jiang, Wenlei Bao, Chengquan Jiang, Qi Hou, Weihao Cui, Size Zheng, Li-Wen Chang, Quan Chen, Xin Liu. _arXiv_, February 2025.
- ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025.
- ["Generalized Neighborhood Attention: Multi-dimensional Sparse Attention at the Speed of Light"](https://arxiv.org/abs/2504.16922). Ali Hassani, Fengzhe Zhou, Aditya Kane, Jiannan Huang, Chieh-Yun Chen, Min Shi, Steven Walton, Markus Hoehnerbach, Vijay Thakkar, Michael Isaev, Qinsheng Zhang, Bing Xu, Haicheng Wu, Wen-mei Hwu, Ming-Yu Liu, Humphrey Shi. _arXiv_, April 2025.
## 2024
- ["DeepSeek-V3 Technical Report"](https://arxiv.org/abs/2412.19437). DeepSeek-AI. _arXiv_, December 2024.
- ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024.
- ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024.
@ -32,9 +38,9 @@
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023.
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, February 2023.
- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, Feburary 2023.
- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, February 2023.
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
@ -64,3 +70,35 @@
"](https://arxiv.org/abs/2008.13006). Cong Guo, Bo Yang Hsueh, Jingwen Leng, Yuxian Qiu, Yue Guan, Zehuan Wang, Xiaoying Jia, Xipeng Li, Minyi Guo, Yuhao Zhu. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.
- ["Strassen's Algorithm Reloaded on GPUs"](https://dl.acm.org/doi/10.1145/3372419). Jianyu Huang, Chenhan D. Yu, Robert A. van de Geijn. _ACM Transactions on Mathematical Software_, March 2020.
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

README.md

@ -1,127 +1,123 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# Overview
# CUTLASS 3.8.0
# CUTLASS 4.0.0
_CUTLASS 3.8.0 - January 2025_
_CUTLASS 4.0.0 - May 2025_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
and scales within CUDA. It incorporates strategies for hierarchical decomposition and
data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes
these "moving parts" into reusable, modular software components abstracted by C++ template
classes. Primitives for different levels of a conceptual parallelization hierarchy
can be specialized and tuned via custom tiling sizes, data types,
and other algorithmic policy. The resulting flexibility simplifies their use
as building blocks within custom kernels and applications.
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
hierarchical decomposition and data movement. CUTLASS decomposes these "moving parts" into reusable, modular
software components and abstractions.
To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
Primitives for different levels of a conceptual parallelization hierarchy can be specialized and tuned
via custom tiling sizes, data types, and other algorithmic policies. The resulting flexibility simplifies
their use as building blocks within custom kernels and applications.
CUTLASS has been providing CUDA C++ template abstractions for high-performance linear algebra since 2017 and
these abstractions provide extensive support for a wide range of computations including
mixed-precision computations, specialized data-movement (async copy) and
multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16,
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
[FP32 emulation via tensor core instruction](https://github.com/NVIDIA/cutlass/tree/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
8b floating point types (e5m2 and e4m3),
block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8),
narrow integer types (4 and 8b signed and unsigned integers),
and binary 1b data types (where architectures allow for the
native support of such data types).
CUTLASS demonstrates optimal matrix multiply operations
native support of such data types) across NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
To this rich ecosystem of C++ based kernel programming abstractions, CUTLASS 4 adds CUTLASS DSLs. These are Python native interfaces for writing high-performance CUDA kernels based on core CUTLASS and CuTe concepts without any performance compromises. This allows for a much smoother learning curve, orders of magnitude faster compile times, native integration with DL frameworks without writing glue code, and much more intuitive metaprogramming that does not require deep C++ expertise.
Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions — exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy.
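For readers coming from the C++ side, the layout and tensor vocabulary exposed by the DSL mirrors CuTe C++. The minimal sketch below shows a CuTe layout acting as a coordinate-to-offset function; it assumes compilation with nvcc against this repository's include directory.

```c++
#include <cstdio>
#include <cute/tensor.hpp>  // CuTe C++ headers shipped with CUTLASS

int main() {
  using namespace cute;

  // An 8x4 column-major layout: shape (8,4) with strides (1,8).
  auto layout = make_layout(make_shape(Int<8>{}, Int<4>{}),
                            make_stride(Int<1>{}, Int<8>{}));

  // A layout maps a logical coordinate to a linear offset: (3,2) -> 3 + 2*8.
  int offset = layout(3, 2);

  print(layout);                           // shape:stride pair, e.g. (_8,_4):(_1,_8)
  std::printf("\noffset = %d\n", offset);  // offset = 19
  return 0;
}
```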
CuTe DSL demonstrates optimal matrix multiply and other linear algebra operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
NVIDIA's Ampere, Hopper, and Blackwell architectures.
In addition to GEMMs, CUTLASS implements high-performance convolution via
the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution
operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline.
This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
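Concretely, the Fprop variant of implicit GEMM treats an NHWC activation tensor convolved with a KRSC filter as a GEMM whose logical dimensions are GEMM_M = N*P*Q, GEMM_N = K, and GEMM_K = R*S*C, with the im2col mapping applied on the fly by the iterators rather than materialized in memory. The struct below is a back-of-the-envelope illustration of that dimension mapping, not CUTLASS iterator code.

```c++
// Illustrative Fprop dimension mapping for implicit GEMM convolution.
// N, P, Q are the output batch and spatial extents; K is the number of
// filters; R, S are the filter spatial extents; C is the input channel count.
struct ImplicitGemmFpropShape {
  int N, P, Q;
  int K;
  int R, S, C;

  long long gemm_m() const { return 1LL * N * P * Q; }  // one row per output pixel
  long long gemm_n() const { return K; }                // one column per filter
  long long gemm_k() const { return 1LL * R * S * C; }  // reduction over the filter window
};
```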
We believe it will become an indispensable tool for students, researchers, and performance
engineers alike — flattening the learning curve of GPU programming, rapidly prototyping kernel
designs, and bringing optimized solutions into production.
See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.
CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
See the [functionality docs](./media/docs/functionality.md) for a more comprehensive
list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU
architecture.
To get started quickly - please refer :
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).
# What's New in CUTLASS 3.8
# What's New in CUTLASS 4.0
CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture.
For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8.
### CuTe DSL
* CuTe DSL, a Python DSL centered around CuTe's abstractions
- [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass)
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
- [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
- [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
- [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
+ Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and a full set of EVT fusions.
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
- Support for preferred and fallback cluster shapes, set via profiler command-line arguments that configure dynamic cluster shapes.
- Support for dynamic data types, set via profiler command-line arguments that configure the dynamic data type in TCGen05 MMA instruction descriptors.
- Support for mixed input GEMM kernels on Hopper in the profiler.
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
* A new 3.x version of grouped GEMM has been added to the CUTLASS library that generates kernels for Hopper and Blackwell. Grouped GEMM support is now enabled in the CUTLASS profiler (run `./cutlass_profiler --operation=GroupedGemm --help` for details; see the sketch below).
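A minimal sketch of how the two profiler additions above might be exercised together; the `--operation=GroupedGemm` and `--help` spellings come from this changelog, while the `--use-cuda-graphs=true` value syntax is an assumption to be checked against `--help`:

```bash
# List the options the profiler exposes for grouped GEMM operations.
./cutlass_profiler --operation=GroupedGemm --help

# Benchmark grouped GEMM kernels with CUDA graphs enabled to reduce launch
# overheads (flag value syntax assumed; verify with --help).
./cutlass_profiler --operation=GroupedGemm --use-cuda-graphs=true
```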
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32, 64, and 128.
- A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
- A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
- A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
- A new [functionality documentation](./media/docs/functionality.md) specifically for the 3.x API, comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for the architectures supported by 3.x.
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
### CUTLASS C++
* Support for [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/), which were introduced in CUDA 12.9
- 100f, 101f, and 120f were added to support Family Specific Architecture Features, which allow running the same binary on different chips belonging to the same family (e.g. sm100) without recompiling; see the CMake sketch below. Note that 101a has been supported since CUTLASS 3.9.
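A minimal CMake sketch of targeting a family-specific architecture, reusing the `CUTLASS_NVCC_ARCHS` convention shown later in this README:

```bash
# Build the profiler for the SM100 family (100f) so the resulting binary can run
# on different chips within that family without recompiling; 100a would instead
# restrict the build to architecture-specific (non-portable) features.
cmake .. -DCUTLASS_NVCC_ARCHS="100f"
make cutlass_profiler -j16
```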
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
- For example:
+ `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
+ `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
- If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, or to filter kernels in the CUTLASS profiler with `--kernels`), please update your filters accordingly; this is a breaking change (see the sketch below).
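For example, a sketch of updating an existing filter from the old naming scheme to the new one, using the flags named above:

```bash
# Old-style filter (no longer matches after the rename):
#   -DCUTLASS_LIBRARY_KERNELS=cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_*
# New-style filter, with instruction shape and redundant accumulator type removed:
cmake .. -DCUTLASS_NVCC_ARCHS="90a" \
         -DCUTLASS_LIBRARY_KERNELS="cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_*"

# The profiler-side filter follows the same naming:
./cutlass_profiler --kernels="cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_*"
```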
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
- Added non-power-of-two tile sizes.
- Improved performance for K-major scale factors.
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
* Support LSE output in Blackwell SM100 FMHA Forward kernel in example 77.
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels.
* Fix profiler issues that caused some kernels to produce no output or report a "not supported" error.
* Optimizations for Blackwell SM100 and SM120 block scaled kernels.
* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler.
* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* CuTe changes:
- Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda, and introduce transform-tensors to avoid any extra register or load/store overhead when using bool-tensors.
- New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) showing the usage of `copy_if` in a tiled copy.
- Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
- Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA Toolkit version 12.9.
Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits.
Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits.
CUTLASS team is working on a fix.
**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.**
**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.**
# Performance
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit nearly optimal utilization of peak theoretical throughput. The figure below
shows CUTLASS 3.8's performance as a % of theoretical peak utilization
on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU.
<p align="center"><img src=media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg></p>
![ALT](media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg "")
The two figures below show the continual CUTLASS performance improvements
on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since
CUTLASS 3.1.
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
Tensor Core operations are implemented using CUDA's
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
![ALT](media/images/cutlass-3.5.1-gemm-peak-performance.png "")
![ALT](media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png "")
# CuTe
@ -143,7 +139,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
This greatly simplifies the design and improves code composability and readability.
More documentation specific to CuTe can be found in its
[dedicated documentation directory](./media/docs/cute/00_quickstart.md).
[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html).
# Compatibility
@ -153,7 +149,7 @@ Minimum requirements:
- Compiler: Must support at least C++17
- CUDA Toolkit version: 11.4
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions.
@ -190,6 +186,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|NVIDIA GeForce RTX 50x0 series |12.0|12.8|
## Target Architecture
@ -213,19 +210,19 @@ the kernel is expected to fail with a runtime error.
```
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
```
Or
```
cmake .. -DCUTLASS_NVCC_ARCHS="100a"
```
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
products has a different compute capability than the one underpinning
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
compiled for Blackwell SM100 architecture with arch conditional features
(using `sm100a`) are not compatible with RTX 50 series GPUs.
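A sketch of the corresponding build targets; `100a` is shown earlier in this section, and `120a` for the GeForce Blackwell parts is taken from the profiler architecture list that appears later in this changeset:

```bash
# Datacenter Blackwell (SM100) with architecture-conditional features:
cmake .. -DCUTLASS_NVCC_ARCHS="100a"

# GeForce RTX 50 series Blackwell has a different compute capability, so it
# needs its own target:
cmake .. -DCUTLASS_NVCC_ARCHS="120a"
```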
Please refer to the [functionality documentation](./media/docs/functionality.md)
Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html)
for details on which kernels require which target architectures.
# Documentation
@ -233,22 +230,22 @@ for details on which kernels require which target architectures.
CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS
- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](./media/docs/terminology.md) - describes terms used in the code
- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development
- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code
- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application
- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
kernels in the same stream, and how it is used in CUTLASS.
# Resources
@ -268,7 +265,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
paths.
CUTLASS unit tests, examples, and utilities can be built with CMake.
The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.
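A minimal sketch of a from-scratch configure step, assuming a default CUDA Toolkit install location (adjust the path as needed):

```bash
# Point CMake at NVCC and configure an out-of-source build targeting Hopper.
export CUDACXX=/usr/local/cuda/bin/nvcc
mkdir -p build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
```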
@ -308,12 +305,12 @@ All tests should pass on supported platforms, though the exact number of tests m
# Project Structure
CUTLASS is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests.
[Doxygen documentation](https://nvidia.github.io/cutlass) provides a complete list of files, classes,
and template concepts defined in the CUTLASS project.
A detailed explanation of the source code organization may be found in the
[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below.
A detailed explanation of the source code organization may be found in the
[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below.
## CUTLASS Template Library
@ -337,7 +334,7 @@ include/ # client applications should target this directory
reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model
thread/ # simt code that can be performed within a CUDA thread
transform/ # code specialized for layout, type, and domain transformations
* # core vocabulary types, containers, and basic numeric operations
@ -362,7 +359,7 @@ include/ # client applications should target this directory
### CUTLASS SDK Examples
[CUTLASS SDK examples](./examples) apply CUTLASS templates to implement basic computations.
[CUTLASS SDK examples](https://github.com/NVIDIA/cutlass/tree/main/examples) apply CUTLASS templates to implement basic computations.
### Tools
@ -375,9 +372,9 @@ tools/
profiler/ # CUTLASS Profiler - command-line utility for executing operations in the
# CUTLASS Library
util/ # CUTLASS Utilities - contains numerous helper classes for
include/ # manging tensors in device memory, reference
include/ # managing tensors in device memory, reference
cutlass/ # implementations for GEMM, random initialization
util/ # of tensors, and I/O.
```
@ -387,7 +384,7 @@ tools/
The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md).
Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
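A sketch of building and running them from the CMake build directory, assuming the `test_unit` aggregate target described in the Quickstart guide:

```bash
# Build and run the Google Test based unit tests.
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
make test_unit -j16
```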
# Performance Profiling
@ -401,7 +398,7 @@ $ make cutlass_profiler -j16
By default, only one tile size is instantiated for each data type, math instruction, and layout.
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
Beware, this results in *tens of thousands* of kernels and long build times.
This would also result in a large binary size and, on some platforms, cause the linker to fail when building the library.
Therefore, it's highly recommended to generate only a subset of kernels as demonstrated in the sub-section below.
```bash
@ -412,13 +409,13 @@ $ make cutlass_profiler -j16
## Building a subset of GEMM and Convolution kernels (_reduced_ build times)
To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
wildcard characters may be used to reduce the set of kernels. The following examples show building exactly one
or a subset of kernels for NVIDIA Ampere and Turing architecture:
### Building a subset of Tensor Core GEMM kernels
To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targeting NVIDIA Ampere and Turing architecture,
use the below cmake command line:
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*gemm_f16_*_nt_align8
@ -507,7 +504,7 @@ $ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
### Building a subset of Tensor Core Convolution kernels
To compile a subset of Tensor core convolution kernels implementing forward propagation (fprop) with FP32 accumulation
and FP16 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line:
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*fprop_optimized_f16
@ -555,7 +552,7 @@ reference_device: Passed
### Building one Convolution CUDA kernel
To compile and run one CUDA Core convolution kernel implementing forward propagation (fprop) with F32 accumulation
and FP32 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line:
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc
@ -603,14 +600,14 @@ reference_device: Passed
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
- [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md)
- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html)
# About
CUTLASS is released by NVIDIA Corporation as Open Source software under the
[3-clause "New" BSD license](LICENSE.txt).
# Contributors

View File

@ -36,7 +36,7 @@ set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
function(cutlass_generate_kernel_filter_and_testlists_files)
function(cutlass_generate_kernel_filter_and_testlist_files)
set(options)
set(oneValueArgs TEST_SET_NAME)
@ -59,30 +59,30 @@ function(cutlass_generate_kernel_filter_and_testlists_files)
)
if(NOT cutlass_FILTER_GENERATION_RESULT EQUAL 0)
message(FATAL_ERROR "Error generating kernel filters and testlists files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log")
message(FATAL_ERROR "Error generating kernel filters and testlist files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log")
endif()
endfunction()
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
set(PROFILER_ARCH_LIST 100a)
set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f)
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
message(FATAL_ERROR "Only SM100a compute capability is supported with profiler-based unit tests")
message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests")
endif()
endforeach()
if(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 0)
message(STATUS "Building for L0 profiler-based functional regressions")
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l0)
cutlass_generate_kernel_filter_and_testlist_files(TEST_SET_NAME kernel_testlist_l0)
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")
elseif (CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 1)
message(STATUS "Building for L1 profiler-based functional regressions")
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l1)
cutlass_generate_kernel_filter_and_testlist_files(TEST_SET_NAME kernel_testlist_l1)
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")

View File

@ -34,7 +34,7 @@
addressable memory, and then store it back into addressable memory.
TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to
and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines
and from addressable memory. The PredicatedTileIterator accepts a ThreadMap type, which defines
the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined
thread mappings to be specified.
@ -124,7 +124,7 @@ __global__ void copy(
cudaError_t TestTileIterator(int M, int K) {
// For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects
// For this example, we chose a <64, 4> tile shape. The PredicatedTileIterator expects
// PitchLinearShape and PitchLinear layout.
using Shape = cutlass::layout::PitchLinearShape<64, 4>;
using Layout = cutlass::layout::PitchLinear;
@ -136,7 +136,7 @@ cudaError_t TestTileIterator(int M, int K) {
// dimension then along the strided dimension.
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<Shape, kThreads>;
// Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types
// Define the PredicatedTileIterator, using TileShape, Element, Layout, and ThreadMap types
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
Shape, Element, Layout, 1, ThreadMap>;

View File

@ -115,4 +115,3 @@ SPDX-License-Identifier: BSD-3-Clause
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -2,3 +2,35 @@
This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface.
For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory.
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -165,3 +165,35 @@ Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -402,7 +402,7 @@ struct Options : MixedDtypeOptions{
void initialize(Options const& options) {
auto shape_B = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
@ -429,7 +429,7 @@ void initialize(Options const& options) {
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
initialize_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);

View File

@ -318,7 +318,7 @@ struct Options : MixedDtypeOptions {
void initialize(Options const& options) {
auto shape_B = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
@ -347,7 +347,7 @@ void initialize(Options const& options) {
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
initialize_tensor(block_B, seed + 2021);
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);

View File

@ -288,7 +288,7 @@ cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::Ele
void initialize(MixedDtypeOptions const& options) {
auto shape_b = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
int const scale_k = cutlass::ceil_div(options.k, options.g);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
@ -313,7 +313,7 @@ void initialize(MixedDtypeOptions const& options) {
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
initialize_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);

View File

@ -41,3 +41,35 @@ We are currently optimizing the following cases:
* Optimizations for memory bound cases.
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -208,20 +208,6 @@ bool initialize_tensor(
return true;
}
template <typename Element>
bool initialize_quant_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed = 2023) {
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <class Element>
bool initialize_scale(
cutlass::DeviceAllocation<Element>& block,
@ -232,10 +218,8 @@ bool initialize_scale(
float scope_max = 1.0f, scope_min = 1.0f;
if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
const float max_dequant_val = 4.f;
const float min_dequant_val = 0.5f;
scope_max = max_dequant_val / elt_max_f;
scope_min = min_dequant_val / elt_max_f;
scope_max = 2.f;
scope_min = 0.1f;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));

View File

@ -489,7 +489,7 @@ int run(Options &options)
std::cout << " Batches : " << options.l << std::endl;
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
}
return 0;

View File

@ -124,7 +124,7 @@ struct CooperativeConfig {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using TileShape = Shape<_256,_128,_128>;
using ClusterShape = Shape<_2,_2,_1>;
using ClusterShape = Shape<_1,_2,_1>;
};
struct PingpongConfig {
@ -296,14 +296,14 @@ struct Options {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = alignment * ((rand() % 64) + 1);
if (m < 0) {
m = alignment * ((rand() % 64));
}
if (n < 1) {
n = alignment * ((rand() % 64) + 1);
if (n < 0) {
n = alignment * ((rand() % 64));
}
if (k < 1) {
k = alignment * ((rand() % 64) + 1);
if (k < 0) {
k = alignment * ((rand() % 64));
}
problem_sizes_host.push_back({m, n, k});
}
@ -333,19 +333,9 @@ struct Options {
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
extent.at(i) = std::atoi(tokens.at(i).c_str());
}
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
groups = static_cast<int>(problem_sizes_host.size());
@ -500,10 +490,27 @@ void initialize(const Options &options) {
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
// If the current group's matrix has size 0, set the pointer to nullptr
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
ptr_A_host.at(i) = nullptr;
} else {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
}
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
ptr_B_host.at(i) = nullptr;
} else {
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
}
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
ptr_C_host.at(i) = nullptr;
} else {
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
}
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
ptr_D_host.at(i) = nullptr;
} else {
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
@ -539,9 +546,10 @@ void initialize(const Options &options) {
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_A, seed + 2021);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
initialize_block(block_C, seed + 2023);
initialize_block(block_D, seed + 2024);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
@ -653,6 +661,13 @@ int run(Options &options, bool host_problem_shapes_available = true)
allocate(options);
initialize(options);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
// Instantiate CUTLASS kernel depending on templates
GemmT gemm;
@ -700,14 +715,8 @@ int run(Options &options, bool host_problem_shapes_available = true)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
}
return 0;

View File

@ -770,9 +770,6 @@ int main(int argc, char const** argv) {
bool satisfied;
if (props.major < 10) {
// Pre-Blackwell
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4);
satisfied &= (props.major > 8) || (props.major == 8 && props.minor == 9);
}
else {
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
@ -786,7 +783,6 @@ int main(int argc, char const** argv) {
std::cout
<< "CUTLASS's FP8 SM89 example requires an NVIDIA GPU with compute capability 8.9 or greater "
<< "and CUDA toolkit version 12.4 or later"
<< " (12.8 or later needed for SM100+)"
<< std::endl;
return 0;

View File

@ -207,3 +207,35 @@ With this in mind, this example kernel has the following limitations:
- This example kernel only supports a dynamic image count; all other conv problem shapes must be defined as `cute::Constant<>`s
- Problem shapes (including dynamic image count `N`) must be evenly divisible by the tile shape
- It does not perform fp32->tf32 numeric conversion, gmem inputs must be rounded to tf32 already
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -26,11 +26,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include_directories(
.
)
set(TEST_PREFETCH_CASE --m=8192 --n=64 --k=8192 --iterations=0)
cutlass_example_add_executable(
63_hopper_gemm_with_weight_prefetch
63_hopper_gemm_with_weight_prefetch.cu
)
TEST_COMMAND_OPTIONS
TEST_PREFETCH_CASE
)
target_include_directories(63_hopper_gemm_with_weight_prefetch PUBLIC .)

View File

@ -74,9 +74,40 @@ echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
However, note that the example still runs a single GEMM, and most of the performance improvement
is expected in end-to-end applications.
## Limitations
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
memory barrier before issuing every single TMA load, and in many cases this will slow down
prefetching to the point of being almost ineffective.
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -42,7 +42,6 @@
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/arch/grid_dependency_control.h"
@ -288,7 +287,7 @@ struct CollectiveMma<
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
@ -445,7 +444,7 @@ struct CollectiveMma<
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
++k_tile_iter;
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
launch_dep_grids = true;
cutlass::arch::launch_dependent_grids();
}
@ -453,7 +452,7 @@ struct CollectiveMma<
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
cutlass::arch::launch_dependent_grids();
}
}
@ -533,7 +532,7 @@ struct CollectiveMma<
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
++k_tile_iter;
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
launch_dep_grids = true;
cutlass::arch::launch_dependent_grids();
}
@ -541,7 +540,7 @@ struct CollectiveMma<
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
cutlass::arch::launch_dependent_grids();
}
}
@ -634,9 +633,9 @@ struct CollectiveMma<
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
@ -854,7 +853,7 @@ struct CollectiveMma<
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();

View File

@ -362,11 +362,11 @@ public:
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z,
/*reserved_named_barriers_*/ 14);
/*id*/ 0);
// Prefetcher warp doesn't arrive on this barrier.
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
/*reserved_named_barriers_*/ 15);
/*id*/ 1);
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
__syncwarp();

View File

@ -120,8 +120,7 @@
#include "helper.h"
// Distributed GEMM helpers
#include "util/benchmark.h"
#include "util/device_copy.h"
#include "dist_gemm_helpers.h"
using namespace cute;
@ -133,8 +132,8 @@ using namespace cute;
using TP = _8;
static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
// Distributed GEMM tiling/sharding schedule
// Choices:
@ -253,7 +252,8 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@ -345,8 +345,8 @@ struct Result {
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -804,17 +804,18 @@ int run(Options &options) {
return 0;
}
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
// CUTLASS must be compiled with CUDA Toolkit 12.6 or newer to run this example
// and must have compute capability at least 90.
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
// Some necessary cuda graph APIs were only introduced in CUDA 12.6.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) {
std::cerr << "This example requires CUDA 12.6 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
@ -834,10 +835,10 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
if (props.major != 9 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater)." << std::endl;
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
<< "(compute capability 90)." << std::endl;
return 0;
}
@ -858,8 +859,12 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
run(options);
#else
std::cerr
<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.6 or later." << std::endl;
return 0;
#endif
return 0;

View File

@ -62,3 +62,40 @@ procedure is the same, simply modify the following line in the example:
```cpp
using TP = _8;
```
## References
* [Distributed GEMM Blog](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b)
* [Distributed GEMM Talk on CUDA Mode](https://www.youtube.com/watch?v=NHRTCQBZokg)
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@@ -17,6 +17,8 @@ Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit ar
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
CUDA graph APIs.
The minimum CUDA driver version for running this example is [560.28.03](https://docs.nvidia.com/cuda/archive/12.6.0/cuda-toolkit-release-notes/index.html#id5).
### Hardware / driver settings
This example requires Hopper GPUs connected with an NVLink network.
@@ -84,3 +86,35 @@ GPU5 OK OK OK OK OK X OK OK
GPU6 OK OK OK OK OK OK X OK
GPU7 OK OK OK OK OK OK OK X
```
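The all-`OK` matrix above means every GPU pair can reach its peers directly. As a rough programmatic cross-check (a sketch, not part of the example; it only confirms CUDA peer-to-peer capability, not that the link is NVLink), one could query peer access for every device pair:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Sketch: print a peer-access matrix similar to the table above.
int main() {
  int num_devices = 0;
  cudaGetDeviceCount(&num_devices);
  for (int i = 0; i < num_devices; ++i) {
    for (int j = 0; j < num_devices; ++j) {
      int can_access = 0;
      if (i != j) {
        cudaDeviceCanAccessPeer(&can_access, i, j);
      }
      std::printf("%s ", i == j ? "X " : (can_access ? "OK" : "--"));
    }
    std::printf("\n");
  }
  return 0;
}
```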
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@@ -75,11 +75,11 @@
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
// Includes from examples directory
#include "helper.h"
#include "hopper_fp8_commandline.hpp"
#include "reference/host/gemm_with_blockwise_scaling.h"
using namespace cute;
@@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
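Note that switching `ElementC` from `cutlass::float_e4m3_t` to `float` also changes the derived alignment: with 32-bit elements a 16-byte access covers 4 elements rather than 16. A quick check of that arithmetic (illustrative only, not code from the example):
```cpp
// 128 bits per access divided by the element width in bits.
static_assert(128 / 32 == 4,  "AlignmentC for float (32-bit) C/D operands");
static_assert(128 / 8  == 16, "AlignmentC for float_e4m3_t (8-bit) C/D operands");
```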
@@ -123,7 +123,13 @@ using ArchTag = cutlass::arch::Sm90; // T
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(TileShape{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
@@ -143,8 +149,8 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
@@ -190,20 +196,21 @@ StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
StrideAux stride_aux;
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
uint64_t seed;
using LayoutScalar = cutlass::layout::PackedVectorLayout;
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
uint32_t mma_promotion_interval;
cutlass::HostTensor<ElementBlockScale, LayoutA> blockscale_tensor_A;
cutlass::HostTensor<ElementBlockScale, LayoutB> blockscale_tensor_B;
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
using LayoutScalar = cutlass::layout::PackedVectorLayout;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
@@ -251,117 +258,116 @@ struct Result
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
return true;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, bits_input);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
double scope_max, scope_min;
scope_min = -1;
scope_max = 1;
scope_min = -1;
scope_max = 1;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options<RasterOrderOptions> &options) {
// Find Block Scaling tensor shapes based on problem shape and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
stride_aux = stride_D;
// Layouts SFA and SFB describe logically broadcast data in CuTe.
// E.g., if layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGranularityK, K / ScaleGranularityK))
// and strides ((0, 1), (0, M / ScaleGranularityM)), then every collection of ScaleGranularityM x ScaleGranularityK
// indices in the tensor maps to the same offset.
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(blockscale_m * options.l, blockscale_k);
auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
tensor_A.resize(a_coord);
blockscale_tensor_A.resize(blockscale_a_coord);
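The comment above describes the scale-factor layouts as broadcast layouts: modes covering one scale granule have stride 0, so every index inside a granule maps to the same storage offset, and `size(filter_zeros(layout))` counts only the distinct scale factors that actually need storage. A small standalone CuTe illustration of that idea (hypothetical sizes, not the example's defaults):

```cpp
#include <cstdio>
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  // Shape ((ScaleGranularityM, M/ScaleGranularityM), (ScaleGranularityK, K/ScaleGranularityK))
  // with stride 0 on the granularity modes, e.g. M = 256, K = 384, granularity 128.
  auto layout_SF = make_layout(
      make_shape (make_shape (Int<128>{}, 2), make_shape (Int<128>{}, 3)),
      make_stride(make_stride(Int<0>{},   1), make_stride(Int<0>{},   2)));
  // The first and the last index inside the first granule map to the same offset:
  std::printf("offsets: %d and %d\n",
              int(layout_SF(make_coord(make_coord(0, 0), make_coord(0, 0)))),
              int(layout_SF(make_coord(make_coord(127, 0), make_coord(127, 0)))));
  // filter_zeros() collapses the broadcast (stride-0) modes, leaving 2 * 3 = 6
  // distinct scale factors to store.
  std::printf("distinct scales = %d\n", int(size(filter_zeros(layout_SF))));
  return 0;
}
```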
@@ -398,8 +404,6 @@ void initialize(const Options<RasterOrderOptions> &options) {
blockscale_tensor_A.sync_device();
blockscale_tensor_B.sync_device();
mma_promotion_interval = 4;
if (options.save_aux) {
tensor_aux.resize(c_coord);
tensor_aux.sync_device();
@@ -434,14 +438,18 @@ void initialize(const Options<RasterOrderOptions> &options) {
if (IsDFp8 && options.save_amax) {
abs_max_D.resize(cutlass::make_Coord(1));
initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_D.sync_device();
reference_abs_max_D.resize(cutlass::make_Coord(1));
initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
}
if (IsAuxFp8 && options.save_aux && options.save_amax) {
abs_max_aux.resize(cutlass::make_Coord(1));
initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_aux.sync_device();
reference_abs_max_aux.resize(cutlass::make_Coord(1));
initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
}
}
@@ -455,9 +463,10 @@ typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &op
stride_A,
tensor_B.device_data(),
stride_B,
mma_promotion_interval,
blockscale_tensor_A.device_data(),
blockscale_tensor_B.device_data()
layout_SFA,
blockscale_tensor_B.device_data(),
layout_SFB
},
{
{}, // epilogue.thread
@@ -511,13 +520,6 @@ bool verify(const Options<RasterOrderOptions> &options) {
// Compute reference output
//
// Block scaling tensor shapes based on CTA tile (TileShape) and GEMM problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(
@@ -550,28 +552,18 @@ bool verify(const Options<RasterOrderOptions> &options) {
)
);
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
cute::make_layout(
cute::make_shape(blockscale_m, blockscale_k, options.l),
cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k)
)
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
cute::make_shape(blockscale_n, blockscale_k, options.l),
cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
)
);
auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA);
auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettMainloopParams<ElementAccumulator,
decltype(A), decltype(B),
decltype(blockscale_A), decltype(blockscale_B),
TileShape> mainloop_params{
A, B, // Operand Tensors
blockscale_A, blockscale_B // Blockwise scaling Tensors
};
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
@@ -604,29 +596,40 @@ bool verify(const Options<RasterOrderOptions> &options) {
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
bool passed = true;
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (false) {
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
}
#if 0
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
#endif
if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
if (options.save_aux) {
tensor_aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (IsAuxFp8 && options.save_amax) {
abs_max_aux.sync_host();
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
}
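The verification path now reports MSE, MRE, and the greatest elementwise error, and replaces exact equality with tolerance-based checks driven by `--epsilon` and `--non-zero-floor`. As a rough sketch of what such a relative comparison does (illustrative only; the exact CUTLASS predicates may differ):

```cpp
#include <algorithm>
#include <cmath>

// Accept b as "relatively equal" to a when the difference is within epsilon of
// the larger magnitude, with magnitudes below non_zero_floor clamped to the
// floor so near-zero reference values do not inflate the relative error.
bool relatively_equal_sketch(double a, double b, double epsilon, double non_zero_floor) {
  double denom = std::max({std::fabs(a), std::fabs(b), non_zero_floor});
  return std::fabs(a - b) <= epsilon * denom;
}
```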
@@ -662,20 +665,22 @@ int run(Options<RasterOrderOptions> &options)
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
if (options.verify) {
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
// if (!result.passed) {
// exit(-1);
// }
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}
else {
result.passed = true;
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
if (iter == options.warmup)
timer.start();
CUTLASS_CHECK(gemm.run());
}
timer.stop();
@@ -700,7 +705,7 @@ int run(Options<RasterOrderOptions> &options)
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
return result.passed;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@@ -746,7 +751,9 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
bool passed = run<Gemm>(options);
if (!passed)
return -1;
#endif
return 0;

View File

@@ -75,11 +75,11 @@
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
// Includes from examples directory
#include "helper.h"
#include "hopper_fp8_commandline.hpp"
#include "reference/host/gemm_with_groupwise_scaling.h"
using namespace cute;
@@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
@@ -120,55 +120,30 @@ using ElementAccumulator = float; // E
using ElementBlockScale = float; // Element type for blockscaling during accumulation
using ElementCompute = float; // Element type for epilogue computation
using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor
// Given TileShape = Shape<_128,_128,_128>:
// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor)
// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling
// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling
// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling
template <int ScaleGranularityM_, int ScaleGranularityN_>
struct GroupScaleConfig {
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
static constexpr int ScaleGranularityM = ScaleGranularityM_;
static constexpr int ScaleGranularityN = ScaleGranularityN_;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
"FP8 scaling granularity must evenly divide tile shape along M.");
static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile,
"FP8 scaling granularity must evenly divide tile shape along N.");
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
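With the granularities fixed above (1 along M, 128 along N and K) against a 128x128x128 tile, every row of A inside a tile carries its own scale while B uses a single scale per tile, i.e. the "1Dx2D" case from the comment. A back-of-the-envelope count of scale factors for the default 1024x512x1024 problem (a hypothetical illustration, not code from the example):
```cpp
// One scale factor per (granularity-M x granularity-K) block of A and per
// (granularity-N x granularity-K) block of B.
constexpr int M = 1024, N = 512, K = 1024;
constexpr int GranM = 1, GranN = 128, GranK = 128;
constexpr int num_scales_A = (M / GranM) * (K / GranK);  // 1024 * 8 = 8192
constexpr int num_scales_B = (N / GranN) * (K / GranK);  //    4 * 8 = 32
static_assert(num_scales_A == 8192 && num_scales_B == 32, "scale-factor counts");
```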
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;
};
using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>;
using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>;
using GroupScale2D1DConfig = GroupScaleConfig<size<0>(TileShape_{}), 1>;
using GroupScale2D2DConfig = GroupScaleConfig<size<0>(TileShape_{}), size<1>(TileShape_{})>;
template <typename ScheduleConfig>
struct GroupScaleGemm {
using ArchTag = typename ScheduleConfig::ArchTag;
using OperatorClass = typename ScheduleConfig::OperatorClass;
using TileShape = typename ScheduleConfig::TileShape;
using ClusterShape = typename ScheduleConfig::ClusterShape;
using KernelSchedule = typename ScheduleConfig::KernelSchedule;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using FusionOperation = typename ScheduleConfig::FusionOperation;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
@@ -179,10 +154,10 @@ struct GroupScaleGemm {
FusionOperation
>::CollectiveOp;
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
@@ -191,38 +166,26 @@ struct GroupScaleGemm {
KernelSchedule
>::CollectiveOp;
using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloopWithGroupWiseScaling,
CollectiveEpilogue
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::StreamKScheduler
>;
using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloopWithGroupWiseScaling,
CollectiveEpilogue,
cutlass::gemm::StreamKScheduler
>;
using GemmDefault = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelDefault>;
using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelStreamK>;
};
using GroupScale1D1DGemm = GroupScaleGemm<GroupScale1D1DConfig>;
using GroupScale1D2DGemm = GroupScaleGemm<GroupScale1D2DConfig>;
using GroupScale2D1DGemm = GroupScaleGemm<GroupScale2D1DConfig>;
using GroupScale2D2DGemm = GroupScaleGemm<GroupScale2D2DConfig>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Extract information from Gemm kernel.
using EpilogueOutputOp = typename GroupScale1D1DGemm::GemmDefault::EpilogueOutputOp;
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
using StrideA = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideA;
using StrideB = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideB;
using StrideC = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideC;
using StrideD = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideD;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideAux = StrideD;
constexpr bool IsDFp8 =
@@ -242,20 +205,22 @@ StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
StrideAux stride_aux;
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
uint64_t seed;
using LayoutScalar = cutlass::layout::PackedVectorLayout;
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
uint32_t mma_promotion_interval;
cutlass::HostTensor<ElementBlockScale, LayoutA> blockscale_tensor_A;
cutlass::HostTensor<ElementBlockScale, LayoutB> blockscale_tensor_B;
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
using LayoutScalar = cutlass::layout::PackedVectorLayout;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
@@ -303,120 +268,114 @@ struct Result
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
return true;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, bits_input);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
double scope_max, scope_min;
scope_min = -1;
scope_max = 1;
scope_min = -1;
scope_max = 1;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
template <typename GroupScaleConfig>
void initialize(const Options<RasterOrderOptions> &options) {
using TileShape = typename GroupScaleConfig::TileShape;
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of B to prevent illegal memory access.
auto blockscale_k = cute::get<2>(blockscale_shape);
assert(options.m % ScaleGranularityM == 0);
assert(options.n % ScaleGranularityN == 0);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
stride_aux = stride_D;
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k);
auto groupscale_b_coord = cutlass::make_Coord(groupscale_n * options.l, blockscale_k);
auto groupscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto groupscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
@@ -453,8 +412,6 @@ void initialize(const Options<RasterOrderOptions> &options) {
blockscale_tensor_A.sync_device();
blockscale_tensor_B.sync_device();
mma_promotion_interval = 4;
if (options.save_aux) {
tensor_aux.resize(c_coord);
tensor_aux.sync_device();
@@ -515,9 +472,10 @@ GemmArguments args_from_options(const Options<RasterOrderOptions> &options)
stride_A,
tensor_B.device_data(),
stride_B,
mma_promotion_interval,
blockscale_tensor_A.device_data(),
blockscale_tensor_B.device_data()
layout_SFA,
blockscale_tensor_B.device_data(),
layout_SFB
},
{
{}, // epilogue.thread
@@ -567,18 +525,11 @@ GemmArguments args_from_options(const Options<RasterOrderOptions> &options)
}
/// Don't know why the compiler does not like verify() being templated...
bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile, const int ScaleNsPerTile) {
bool verify(const Options<RasterOrderOptions> &options) {
//
// Compute reference output
//
// Group scaling tensor shapes based on `ScaleGranularityM`, CTA tile (TileShape), and GEMM problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(
@@ -611,28 +562,18 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
)
);
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
cute::make_layout(
cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
)
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l),
cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile)
)
);
auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA);
auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettMainloopParams<ElementAccumulator,
decltype(A), decltype(B),
decltype(blockscale_A), decltype(blockscale_B),
TileShape_> mainloop_params{
A, B, // Operand Tensors
blockscale_A, blockscale_B // Groupwise scaling Tensors
};
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
@@ -665,29 +606,40 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
bool passed = true;
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (false) {
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
}
#if 0
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
#endif
if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
if (options.save_aux) {
tensor_aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
if (IsAuxFp8 && options.save_amax) {
abs_max_aux.sync_host();
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
}
}
@@ -695,16 +647,34 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
}
/// Execute a given example GEMM computation
template <typename GroupScaleConfig, typename Gemm>
int run(Options<RasterOrderOptions> &options)
{
using TileShape = typename GroupScaleConfig::TileShape;
const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM;
const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN;
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
int run(Options<RasterOrderOptions> &options) {
initialize<GroupScaleConfig>(options);
bool skip = false;
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
if (options.m < ScaleGranularityM) {
std::cout << " Skippig (m size: " << options.m << " less than ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
skip = true;
}
if (options.n < ScaleGranularityN) {
std::cout << " Skippig (n size: " << options.n << " less than ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
skip = true;
}
if (options.k < size<2>(TileShape{})) {
std::cout << " Skippig (k size: " << options.k << " less than TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
skip = true;
}
if (!skip) std::cout << " Running... " << std::endl;
else return -1;
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
@@ -729,20 +699,22 @@ int run(Options<RasterOrderOptions> &options)
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
if (options.verify) {
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
// if (!result.passed) {
// exit(-1);
// }
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}
else {
result.passed = true;
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
if (iter == options.warmup)
timer.start();
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
@@ -762,17 +734,13 @@ int run(Options<RasterOrderOptions> &options)
raster = "Along M";
}
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
fflush(stdout);
}
return 0;
return result.passed;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@@ -818,27 +786,10 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
std::cout << "Basic split-K GEMM kernel" << std::endl;
run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmDefault>(options);
std::cout << std::endl;
run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmDefault>(options);
std::cout << std::endl;
std::cout << std::endl;
std::cout << "StreamK GEMM kernel" << std::endl;
run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmStreamK>(options);
std::cout << std::endl;
run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmStreamK>(options);
std::cout << std::endl;
bool passed = true;
passed = run(options);
if (!passed)
return -1;
#endif
return 0;

View File

@@ -34,6 +34,7 @@ template<typename RasterOrderOptions>
struct Options {
bool help = false;
bool verify = true;
float alpha = 1.f, beta = 0.f;
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
@@ -41,9 +42,12 @@ struct Options {
bool save_aux = true;
bool save_amax = true;
int iterations = 1000;
int warmup = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
RasterOrderOptions raster;
int swizzle;
float epsilon = 0.02f;
float non_zero_floor = 1.f;
// Parses the command line
void parse(int argc, char const **args) {
@@ -68,7 +72,11 @@ struct Options {
cmd.get_cmd_line_argument("device_scale", device_scale, false);
cmd.get_cmd_line_argument("save_aux", save_aux, true);
cmd.get_cmd_line_argument("save_amax", save_amax, true);
cmd.get_cmd_line_argument("warmup", warmup);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("verify", verify);
cmd.get_cmd_line_argument("epsilon", epsilon);
cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor);
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
@@ -89,8 +97,8 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n"
<< " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
@@ -109,11 +117,14 @@ struct Options {
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --verify=<bool> Verify the results.\n\n"
<< " --epsilon=<float> The epsilon value for comparing the results.\n\n"
<< " --non-zero-floor=<float> The none zero floor for comparing the results.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
<< "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}

View File

@@ -1,504 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for GETT in host-side code.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/gemm.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/relatively_equal.h"
#include <iostream>
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::reference::host {
template<class T, class = void>
struct ElementTraits {
using type = T;
};
template<class T>
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
using type = decltype(std::declval<T>().get());
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementAccumulator_,
class TensorA_, // (M, K, L)
class TensorB_, // (N, K, L)
class TensorScaleA_, // (m, k, L)
class TensorScaleB_, // (n, k, L)
class TileShape_
>
struct GettMainloopParams {
using ElementAccumulator = ElementAccumulator_;
using TensorA = TensorA_;
using TensorB = TensorB_;
using EngineA = typename TensorA::engine_type;
using LayoutA = typename TensorA::layout_type;
using EngineB = typename TensorB::engine_type;
using LayoutB = typename TensorB::layout_type;
using TensorScaleA = TensorScaleA_;
using TensorScaleB = TensorScaleB_;
using TileShape = TileShape_;
using EngineScaleA = typename TensorScaleA::engine_type;
using EngineScaleB = typename TensorScaleB::engine_type;
TensorA A{};
TensorB B{};
TensorScaleA ScaleA{};
TensorScaleB ScaleB{};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementScalingFactor_,
class ElementAccumulator_,
class ElementCompute_,
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = TensorD_, // (M, 1)
class TensorAux_ = TensorD_, // (M, N, L)
class VectorAlpha_ = TensorD_, // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
bool PerColumnBias_ = false
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
using ElementScalingFactor = ElementScalingFactor_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using TensorC = TensorC_;
using TensorD = TensorD_;
using TensorAux = TensorAux_;
using VectorBias = VectorBias_;
using VectorAlpha = VectorAlpha_;
using VectorBeta = VectorBeta_;
using ActivationFunctor = ActivationFunctor_;
using BiasBinaryOp = BiasBinaryOp_;
using EngineC = typename TensorC::engine_type;
using LayoutC = typename TensorC::layout_type;
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
static constexpr bool PerColumnBias = PerColumnBias_;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
TensorC C{};
TensorD D{};
VectorBias Bias{};
TensorAux Aux{};
VectorAlpha Valpha{};
VectorBeta Vbeta{};
ElementCompute st = ElementCompute(1);
ElementAccumulator* abs_max_D = nullptr;
ElementAccumulator* abs_max_Aux = nullptr;
ElementScalingFactor scale_a = ElementScalingFactor(1);
ElementScalingFactor scale_b = ElementScalingFactor(1);
ElementScalingFactor scale_c = ElementScalingFactor(1);
ElementScalingFactor scale_d = ElementScalingFactor(1);
ElementScalingFactor scale_aux = ElementScalingFactor(1);
bool beta_per_channel_scaling = false;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - General Tensor-Tensor contraction reference kernel with Blockwise scaling
template <
class MainloopParams,
class EpilogueParams
>
void Gett(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
#if defined(_OPENMP)
#pragma omp parallel for collapse(3)
#endif
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
gett_mainloop(mainloop_params, m, n, l, acc);
gett_epilogue(epilogue_params, m, n, l, acc);
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Mainloop
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_mainloop(
MainloopParams const& mainloop_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
RingOp fma_op;
multiplies<ElementAccumulator> scale_op;
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});
// Temporary accumulators to separate blockwise accumulation
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
// Zero out accumulators
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
int64_t block_m = m / kBlockM;
int64_t block_n = n / kBlockN;
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, l);
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l);
// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
// Load Blockwise scaling factor from blockscale Tensors for A and B
int64_t block_k = k / kBlockK;
ElementBlockScaleA scale_a = blockscale_A[block_k];
ElementBlockScaleB scale_b = blockscale_B[block_k];
// Load A
ElementAccumulator a_frag[kBlockM];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
} else {
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// Load B
ElementAccumulator b_frag[kBlockN];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
} else {
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// do compute
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
}
}
// Apply blockwise scaling at the kBlockK boundary:
// (a) apply the block scaling factors to the partial accumulation (acc_temp),
// (b) zero out the temporary accumulator (acc_temp),
// (c) update the permanent accumulator (acc).
if ((k+1) % kBlockK == 0) {
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a * scale_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Epilogue
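/// Per output element, the non-backprop-fusion path below computes (scale_d and the abs-max
/// reduction are only applied when D is an FP8 type; alpha/beta may instead come per-row from
/// Valpha/Vbeta):
///   tmp = alpha * scale_a * scale_b * acc
///   tmp = bias_op(tmp, bias)             (if a Bias vector is provided)
///   tmp = beta * scale_c * C + tmp       (if a C source is provided)
///   D   = convert(scale_d * activation(tmp))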
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_epilogue(
EpilogueParams const& epilogue_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, N, B");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "M, N, B");
using cute::raw_pointer_cast;
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::TensorC::value_type;
using ElementD = typename EpilogueParams::TensorD::value_type;
using ElementAux = typename EpilogueParams::TensorAux::value_type;
using ElementBias = typename EpilogueParams::VectorBias::value_type;
using ElementScalar = typename EpilogueParams::ElementScalar;
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
constexpr bool IsReLUAuxNeeded =
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
constexpr bool IsClamp =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
constexpr bool IsBackpropFusion =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
// Input related converter
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
NumericConverter<ElementCompute, ElementC> source_converter;
NumericConverter<ElementCompute, ElementBias> bias_converter;
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
// Scale related converter
NumericConverter<ElementCompute, ElementScalar> scale_converter;
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
// Abs max converter
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
// Output related converter
NumericConverter<ElementD, ElementCompute> destination_converter;
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
NumericConverter<ElementBias, ElementCompute> dBias_converter;
// Epilogue operations
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
multiplies<ElementCompute> mul;
plus<ElementCompute> add;
// Activation operation
ActivationFunctor activation;
// Bias binary operation
BiasBinaryOp bias_op;
// Do conversion
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
// Init local var
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
converted_beta = mul(converted_beta, converted_scale_c);
ElementCompute inter_accum[kBlockM][kBlockN];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
ElementCompute local_dBias = ElementCompute(0);
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
// per-row alpha
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
}
ElementCompute output = mul(converted_alpha, converted_acc);
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
output = bias_op(output, converted_bias);
}
if (raw_pointer_cast(epilogue_params.C.data())) {
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
// per-row beta
if (epilogue_params.Vbeta.data()) {
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
}
output = epilogue_fma(converted_beta, converted_src, output);
}
if constexpr (IsBackpropFusion) {
ElementAux aux_input = ElementAux(0);
if (raw_pointer_cast(epilogue_params.Aux.data())) {
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
}
output = activation(output, aux_source_converter(aux_input));
local_dBias = add(local_dBias, output);
}
else {
if (raw_pointer_cast(epilogue_params.Aux.data())) {
auto aux_output = output;
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
}
if constexpr (IsReLUAuxNeeded) {
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
} else {
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
}
}
if constexpr (IsClamp) { // Treat Clamp as ReLU
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
}
else {
output = activation(output);
}
}
if constexpr (IsScalingAndAmaxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_output = amax_op(local_abs_max_output, output);
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
}
inter_accum[m_b][n_b] = ElementCompute(output);
}
} // n_b
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
local_dBias = add(local_dBias, converted_dBias);
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
}
}
} // m_b
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
}
}
}
#if defined(_OPENMP)
#pragma omp critical(Abs_Max_Data_Update)
#endif
{
if constexpr (IsScalingAndAmaxOutputNeeded) {
if (epilogue_params.abs_max_D) {
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
}
}
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
if (epilogue_params.abs_max_Aux) {
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM - General Matrix-Matrix contraction without conjugation options
template <
class MainloopParams,
class EpilogueParams
>
void Gemm3x(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
using namespace cute;
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only rank-3 tensors (M, K, Batch_Count) "
"in batch mode are supported");
// Lower the Matrix-Multiplication with Blockwise scaling (Gemm3x) to a Tensor Contraction (Gett).
Gett(mainloop_params, epilogue_params);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // cutlass::reference::host
/////////////////////////////////////////////////////////////////////////////////////////////////
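To make the blockwise scaling applied in gett_mainloop above concrete, here is a minimal standalone sketch (plain C++ with hypothetical sizes and values, not part of the changeset): unscaled partial products are accumulated in a temporary and folded into the final accumulator with the per-k-block scales of A and B at each kBlockK boundary.

#include <cstdio>

int main() {
  constexpr int kBlockK = 2;                    // scaling granularity along K (hypothetical)
  constexpr int K       = 4;                    // two k-blocks
  float a[K] = {1.f, 2.f, 3.f, 4.f};            // one row of A
  float b[K] = {0.5f, 0.5f, 0.25f, 0.25f};      // one column of B
  float scale_a[K / kBlockK] = {2.f, 4.f};      // per-k-block scales of A for this tile row
  float scale_b[K / kBlockK] = {1.f, 0.5f};     // per-k-block scales of B for this tile column
  float acc = 0.f, acc_temp = 0.f;
  for (int k = 0; k < K; ++k) {
    acc_temp += a[k] * b[k];                    // unscaled partial accumulation
    if ((k + 1) % kBlockK == 0) {               // reached a kBlockK boundary
      int block_k = k / kBlockK;
      acc += acc_temp * scale_a[block_k] * scale_b[block_k];
      acc_temp = 0.f;                           // restart the partial accumulator
    }
  }
  // (1*0.5 + 2*0.5) * 2*1 + (3*0.25 + 4*0.25) * 4*0.5 = 6.5
  std::printf("blockwise-scaled dot product: %f\n", acc);
  return 0;
}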


@ -1,511 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for GETT in host-side code.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/gemm.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/relatively_equal.h"
#include <iostream>
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::reference::host {
template<class T, class = void>
struct ElementTraits {
using type = T;
};
template<class T>
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
using type = decltype(std::declval<T>().get());
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementAccumulator_,
class TensorA_, // (M, K, L)
class TensorB_, // (N, K, L)
class TensorScaleA_, // (m, k, L)
class TensorScaleB_, // (n, k, L)
class TileShape_
>
struct GettMainloopParams {
using ElementAccumulator = ElementAccumulator_;
using TensorA = TensorA_;
using TensorB = TensorB_;
using EngineA = typename TensorA::engine_type;
using LayoutA = typename TensorA::layout_type;
using EngineB = typename TensorB::engine_type;
using LayoutB = typename TensorB::layout_type;
using TensorScaleA = TensorScaleA_;
using TensorScaleB = TensorScaleB_;
using TileShape = TileShape_;
using EngineScaleA = typename TensorScaleA::engine_type;
using EngineScaleB = typename TensorScaleB::engine_type;
TensorA A{};
TensorB B{};
TensorScaleA ScaleA{};
TensorScaleB ScaleB{};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementScalingFactor_,
class ElementAccumulator_,
class ElementCompute_,
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = TensorD_, // (M, 1)
class TensorAux_ = TensorD_, // (M, N, L)
class VectorAlpha_ = TensorD_, // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
bool PerColumnBias_ = false
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
using ElementScalingFactor = ElementScalingFactor_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using TensorC = TensorC_;
using TensorD = TensorD_;
using TensorAux = TensorAux_;
using VectorBias = VectorBias_;
using VectorAlpha = VectorAlpha_;
using VectorBeta = VectorBeta_;
using ActivationFunctor = ActivationFunctor_;
using BiasBinaryOp = BiasBinaryOp_;
using EngineC = typename TensorC::engine_type;
using LayoutC = typename TensorC::layout_type;
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
static constexpr bool PerColumnBias = PerColumnBias_;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
TensorC C{};
TensorD D{};
VectorBias Bias{};
TensorAux Aux{};
VectorAlpha Valpha{};
VectorBeta Vbeta{};
ElementCompute st = ElementCompute(1);
ElementAccumulator* abs_max_D = nullptr;
ElementAccumulator* abs_max_Aux = nullptr;
ElementScalingFactor scale_a = ElementScalingFactor(1);
ElementScalingFactor scale_b = ElementScalingFactor(1);
ElementScalingFactor scale_c = ElementScalingFactor(1);
ElementScalingFactor scale_d = ElementScalingFactor(1);
ElementScalingFactor scale_aux = ElementScalingFactor(1);
bool beta_per_channel_scaling = false;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling
template <
class MainloopParams,
class EpilogueParams
>
void Gett(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
#if defined(_OPENMP)
#pragma omp parallel for collapse(3)
#endif
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
gett_mainloop(mainloop_params, m, n, l, acc);
gett_epilogue(epilogue_params, m, n, l, acc);
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Mainloop
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_mainloop(
MainloopParams const& mainloop_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
RingOp fma_op;
multiplies<ElementAccumulator> scale_op;
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});
// Temporary accumulators to separate blockwise accumulation
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
// Zero out accumulators
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
int64_t block_m = m / kBlockM;
int64_t block_n = n / kBlockN;
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l);
const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape());
assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape()));
// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
// Load groupwise scaling factor vectors from the blockscale tensors for A and B
int64_t block_k = k / kBlockK;
cute::Tensor scale_a = blockscale_A(_, block_k);
cute::Tensor scale_b = blockscale_B(_, block_k);
// Load A
ElementAccumulator a_frag[kBlockM];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
} else {
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// Load B
ElementAccumulator b_frag[kBlockN];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
} else {
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// do compute
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
}
}
// Apply groupwise scaling at the kBlockK boundary:
// (a) apply the group/block scaling factors to the partial accumulation (acc_temp),
// (b) zero out the temporary accumulator (acc_temp),
// (c) update the permanent accumulator (acc).
if ((k+1) % kBlockK == 0) {
for (int m_b = 0; m_b < kBlockM; ++m_b) {
auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Epilogue
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_epilogue(
EpilogueParams const& epilogue_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, N, B");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "M, N, B");
using cute::raw_pointer_cast;
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::TensorC::value_type;
using ElementD = typename EpilogueParams::TensorD::value_type;
using ElementAux = typename EpilogueParams::TensorAux::value_type;
using ElementBias = typename EpilogueParams::VectorBias::value_type;
using ElementScalar = typename EpilogueParams::ElementScalar;
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
constexpr bool IsReLUAuxNeeded =
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
constexpr bool IsClamp =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
constexpr bool IsBackpropFusion =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
// Input related converter
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
NumericConverter<ElementCompute, ElementC> source_converter;
NumericConverter<ElementCompute, ElementBias> bias_converter;
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
// Scale related converter
NumericConverter<ElementCompute, ElementScalar> scale_converter;
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
// Abs max converter
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
// Output related converter
NumericConverter<ElementD, ElementCompute> destination_converter;
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
NumericConverter<ElementBias, ElementCompute> dBias_converter;
// Epilogue operations
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
multiplies<ElementCompute> mul;
plus<ElementCompute> add;
// Activation operation
ActivationFunctor activation;
// Bias binary operation
BiasBinaryOp bias_op;
// Do conversion
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
// Init local var
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
converted_beta = mul(converted_beta, converted_scale_c);
ElementCompute inter_accum[kBlockM][kBlockN];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
ElementCompute local_dBias = ElementCompute(0);
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
// per-row alpha
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
}
ElementCompute output = mul(converted_alpha, converted_acc);
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
output = bias_op(output, converted_bias);
}
if (raw_pointer_cast(epilogue_params.C.data())) {
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
// per-row beta
if (epilogue_params.Vbeta.data()) {
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
}
output = epilogue_fma(converted_beta, converted_src, output);
}
if constexpr (IsBackpropFusion) {
ElementAux aux_input = ElementAux(0);
if (raw_pointer_cast(epilogue_params.Aux.data())) {
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
}
output = activation(output, aux_source_converter(aux_input));
local_dBias = add(local_dBias, output);
}
else {
if (raw_pointer_cast(epilogue_params.Aux.data())) {
auto aux_output = output;
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
}
if constexpr (IsReLUAuxNeeded) {
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
} else {
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
}
}
if constexpr (IsClamp) { // Treat Clamp as ReLU
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
}
else {
output = activation(output);
}
}
if constexpr (IsScalingAndAmaxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_output = amax_op(local_abs_max_output, output);
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
}
inter_accum[m_b][n_b] = ElementCompute(output);
}
} // n_b
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
local_dBias = add(local_dBias, converted_dBias);
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
}
}
} // m_b
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
}
}
}
#if defined(_OPENMP)
#pragma omp critical(Abs_Max_Data_Update)
#endif
{
if constexpr (IsScalingAndAmaxOutputNeeded) {
if (epilogue_params.abs_max_D) {
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
}
}
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
if (epilogue_params.abs_max_Aux) {
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM - General Matrix-Matrix contraction without conjugation options
template <
class MainloopParams,
class EpilogueParams
>
void Gemm3x(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
using namespace cute;
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only rank-3 tensors (M, K, Batch_Count) "
"in batch mode are supported");
// Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett).
Gett(mainloop_params, epilogue_params);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // cutlass::reference::host
/////////////////////////////////////////////////////////////////////////////////////////////////
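The removed reference above differs from the blockwise variant mainly in how scales are looked up: each k-block carries a small vector of scales per tile rather than a single scalar, and row m_b (column n_b) of the tile selects its entry by integer division with the scale granularity, as in scale_a[m_b / ScaleGranularityM]. A minimal standalone sketch of that index mapping, with hypothetical sizes and values:

#include <cstdio>

int main() {
  constexpr int kBlockM = 8, ScaleGranularityM = 4;           // hypothetical tile size / granularity
  constexpr int ScaleMsPerTile = kBlockM / ScaleGranularityM; // 2 scale entries per tile per k-block
  float scale_a[ScaleMsPerTile] = {2.0f, 0.5f};               // per-group scales for one k-block
  for (int m_b = 0; m_b < kBlockM; ++m_b) {
    // Rows 0..3 use scale_a[0], rows 4..7 use scale_a[1], mirroring scale_a[m_b / ScaleGranularityM]
    std::printf("row %d -> scale %.1f\n", m_b, scale_a[m_b / ScaleGranularityM]);
  }
  return 0;
}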


@ -0,0 +1,771 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped-scale Hopper FP8 grouped GEMM example using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture
This example demonstrates a grouped-scale FP8 grouped GEMM using the new CUTLASS 3.0
APIs on the NVIDIA Hopper architecture. The new features showcased in this example are as follows:
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
which are more efficient than the Ampere tensor core instructions.
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
descriptors to move between groups/problem_count (represented by groups).
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
Examples:
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling \
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
Where test_benchmark.txt may look like the following, one "<index> <M>x<N>x<K>" entry per line
(an illustrative standalone parser sketch follows this comment):
0 256x512x128
1 256x512x512
2 512x256x128
3 256x256x128
4 256x512x1024
5 1024x512x128
and so on
*/
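// Illustrative standalone sketch (not part of this example) of parsing the benchmark file format
// described in the comment above; the example itself reads it through its command-line helper in
// hopper_fp8_commandline.hpp. Each line holds an index followed by an MxNxK problem size.
#include <fstream>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

inline std::vector<std::tuple<int, int, int>> parse_benchmark_file_sketch(std::string const& path) {
  std::vector<std::tuple<int, int, int>> problems;
  std::ifstream file(path);
  std::string line;
  while (std::getline(file, line)) {
    std::istringstream iss(line);
    int index, m, n, k;
    char x1, x2;                            // consume the 'x' separators in "MxNxK"
    if (iss >> index >> m >> x1 >> n >> x2 >> k) {
      problems.emplace_back(m, n, k);       // keep only the problem shape
    }
  }
  return problems;
}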
#include <iostream>
#include <optional>
#include <fstream>
#include <sstream>
#include <vector>
#include <cfloat>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
// Includes from examples directory
#include "helper.h"
#include "hopper_fp8_commandline.hpp"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementBlockScale = float; // Element type for blockscaling during accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
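// With TileShape = 128x128x128 and the granularities above: ScaleMsPerTile = 128 / 1 = 128
// (one scale per row of the CTA tile) and ScaleNsPerTile = 128 / 128 = 1 (one scale per tile
// along N), i.e. per-row scaling of A and per-tile scaling of B within each 128-wide k-block.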
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC *, AlignmentC,
ElementD, LayoutD *, AlignmentD,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopWithGroupWiseScaling,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Extract information from Gemm kernel.
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
"ElementAccumulator and ElementBlockScale should be same datatype");
/// Initialization
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_blockscale_A;
std::vector<int64_t> offset_blockscale_B;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
uint64_t seed;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementD> block_D;
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
cutlass::DeviceAllocation<const ElementA *> ptr_A;
cutlass::DeviceAllocation<const ElementB *> ptr_B;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<ElementD *> ptr_D;
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023,
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
double _scope_max, _scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
_scope_max = 2;
_scope_min = 0;
} else if (bits_input <= 8) {
_scope_max = 2;
_scope_min = -2;
} else if (bits_input == 16) {
_scope_max = 5;
_scope_min = -5;
} else {
_scope_max = 8;
_scope_min = -8;
}
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
_scope_max = scope_max;
}
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
_scope_min = scope_min;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
return true;
}
/// Allocates device-side data
template <typename OptionType>
void allocate(const OptionType &options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
int64_t total_elements_blockscale_A = 0;
int64_t total_elements_blockscale_B = 0;
offset_A.clear();
offset_B.clear();
offset_C.clear();
offset_D.clear();
offset_blockscale_A.clear();
offset_blockscale_B.clear();
stride_A_host.clear();
stride_B_host.clear();
stride_C_host.clear();
stride_D_host.clear();
layout_SFA_host.clear();
layout_SFB_host.clear();
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_blockscale_A.push_back(total_elements_blockscale_A);
offset_blockscale_B.push_back(total_elements_blockscale_B);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA));
int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB));
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_blockscale_A += elements_blockscale_A;
total_elements_blockscale_B += elements_blockscale_B;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
layout_SFA_host.push_back(group_layout_SFA);
layout_SFB_host.push_back(group_layout_SFB);
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
blockscale_block_A.reset(total_elements_blockscale_A);
blockscale_block_B.reset(total_elements_blockscale_B);
}
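// For example, with two hypothetical groups of shapes 256x512x128 and 512x256x128:
//   offset_A = {0, 256*128}, offset_B = {0, 128*512}, offset_C = offset_D = {0, 256*512},
// so each group's operands are carved out of one contiguous device allocation per tensor.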
/// Initialize operands to be used in the GEMM and reference GEMM
template <typename OptionType>
void initialize(const OptionType &options) {
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementD *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
alpha_host.clear();
beta_host.clear();
for (int i = 0; i < options.groups; i++) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
ptr_blockscale_A.reset(options.groups);
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
ptr_blockscale_B.reset(options.groups);
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
layout_SFA.reset(options.groups);
layout_SFA.copy_from_host(layout_SFA_host.data());
layout_SFB.reset(options.groups);
layout_SFB.copy_from_host(layout_SFB_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2022);
initialize_block(block_B, seed + 2023);
initialize_block(block_C, seed + 2024);
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
template<typename GemmArguments, typename OptionType>
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
{
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
int device_id = 0;
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename Gemm::GemmKernel>(device_id);
GemmArguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
ptr_blockscale_A.get(), layout_SFA.get(),
ptr_blockscale_B.get(), layout_SFB.get()
},
{
{}, // epilogue.thread
ptr_C.get(), stride_C.get(),
ptr_D.get(), stride_D.get()
},
kernel_hw_info
};
auto &fusion_args = arguments.epilogue.thread;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per each group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler will use a CTA swizzle of up to 8, constrained to the nearest power of 2 (i.e., 1, 2, 4, or 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
template <typename OptionType>
bool verify(const OptionType &options) {
//
// Compute reference output
//
std::vector<ElementA> block_A_host(block_A.size());
std::vector<ElementB> block_B_host(block_B.size());
std::vector<ElementC> block_C_host(block_C.size());
std::vector<ElementD> block_D_host_kernel(block_D.size());
std::vector<ElementD> block_D_host_ref(block_D.size());
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
block_A.copy_to_host(block_A_host.data());
block_B.copy_to_host(block_B_host.data());
block_C.copy_to_host(block_C_host.data());
block_D.copy_to_host(block_D_host_kernel.data());
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
bool passed = true;
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
// Group scaling tensor shapes are derived from `ScaleGranularityM`/`ScaleGranularityN`, the CTA tile shape (TileShape), and the GEMM problem shape
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
auto gemm_problem_shape = cute::make_shape(m, n, k);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
cute::make_layout(
cute::make_shape(m, k, 1),
stride_A_host.at(group_idx)
)
);
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
cute::make_layout(
cute::make_shape(n, k, 1),
stride_B_host.at(group_idx)
)
);
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
cute::make_layout(
cute::make_shape(m, n, 1),
stride_C_host.at(group_idx)
)
);
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
cute::make_layout(
cute::make_shape(m, n, 1),
stride_D_host.at(group_idx)
)
);
auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
layout_SFA_host.at(group_idx));
auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
layout_SFB_host.at(group_idx));
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
ElementScalar,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
unused_t, // Aux
unused_t, // valpha
unused_t // vbeta
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = alpha_host.at(group_idx);
epilogue_params.beta = beta_host.at(group_idx);
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Check if output from CUTLASS kernel and reference kernel are equal or not
auto this_group_passed = std::equal(
// std::execution::par_unseq,
block_D_host_ref.data() + offset_D.at(group_idx),
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
block_D_host_kernel.data() + offset_D.at(group_idx)
);
passed &= this_group_passed;
#if 0
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
#endif
}
return passed;
}
/// Execute a given example GEMM computation
template <typename OptionType>
int run(OptionType &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
fflush(stdout);
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.3 Toolkit (or newer) to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
//
// Parse options
//
Options<ProblemShape> options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
std::cout << "Running tests with host problem shapes:" << std::endl;
run(options, true);
std::cout << "Running tests without host problem shapes:" << std::endl;
run(options, false);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,779 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Group-wise scaled Hopper FP8 grouped GEMM example using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture
This example demonstrates a group-wise scaled FP8 grouped GEMM using the new CUTLASS 3.0
APIs on the NVIDIA Hopper architecture. New features showcased in this example are as follows:
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
which are more efficient than the Ampere tensor core instructions.
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
descriptors to move between groups/problem_count (represented by groups).
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
5. This example is tuned specifically for the sparse groups case, where the number of active groups (groups
with non-zero problem count) is much smaller than the total number of groups.
Examples:
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups \
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
where test_benchmark.txt might look like this:
0 256x512x128
1 256x512x512
2 512x256x128
3 256x256x128
4 256x512x1024
5 1024x512x128 and so on
*/
#include <iostream>
#include <optional>
#include <fstream>
#include <sstream>
#include <vector>
#include <cfloat>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
// Includes from examples directory
#include "helper.h"
#include "hopper_fp8_commandline.hpp"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementBlockScale = float; // Element type for blockscaling during accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
static constexpr int ScaleGranularityM = 1;
static constexpr int ScaleGranularityN = 128;
static constexpr int ScaleGranularityK = 128;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
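// With the 128x128x128 TileShape above, ScaleMsPerTile = 128/1 = 128 and ScaleNsPerTile = 128/128 = 1,
// i.e., A carries one scale factor per row of the CTA tile and B one scale factor per 128-column block.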
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC *, AlignmentC,
ElementD, LayoutD *, AlignmentD,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopWithGroupWiseScaling,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Extract information from Gemm kernel.
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
"ElementAccumulator and ElementBlockScale should be same datatype");
/// Initialization
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_blockscale_A;
std::vector<int64_t> offset_blockscale_B;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
uint64_t seed;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementD> block_D;
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
cutlass::DeviceAllocation<const ElementA *> ptr_A;
cutlass::DeviceAllocation<const ElementB *> ptr_B;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<ElementD *> ptr_D;
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
double gbps;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
double gbps = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), gbps(gbps), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023,
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
double _scope_max, _scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
_scope_max = 2;
_scope_min = 0;
} else if (bits_input <= 8) {
_scope_max = 2;
_scope_min = -2;
} else if (bits_input == 16) {
_scope_max = 5;
_scope_min = -5;
} else {
_scope_max = 8;
_scope_min = -8;
}
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
_scope_max = scope_max;
}
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
_scope_min = scope_min;
}
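// Note: the trailing 0 passed to BlockFillRandomUniform below truncates the random values to integers
// (no fractional bits), which keeps the narrow FP8 inputs exactly representable and makes the
// bit-exact std::equal comparison in verify() meaningful.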
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
return true;
}
/// Allocates device-side data
template <typename OptionType>
void allocate(const OptionType &options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
int64_t total_elements_blockscale_A = 0;
int64_t total_elements_blockscale_B = 0;
offset_A.clear();
offset_B.clear();
offset_C.clear();
offset_D.clear();
offset_blockscale_A.clear();
offset_blockscale_B.clear();
stride_A_host.clear();
stride_B_host.clear();
stride_C_host.clear();
stride_D_host.clear();
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_after_alignment_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
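// Loosely, SFA holds one scale factor per (ScaleGranularityM x ScaleGranularityK) block of A and SFB
// one per (ScaleGranularityN x ScaleGranularityK) block of B; the exact element counts are read off
// these layouts via filter_zeros() below.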
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_blockscale_A.push_back(total_elements_blockscale_A);
offset_blockscale_B.push_back(total_elements_blockscale_B);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA));
int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB));
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_blockscale_A += elements_blockscale_A;
total_elements_blockscale_B += elements_blockscale_B;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
layout_SFA_host.push_back(group_layout_SFA);
layout_SFB_host.push_back(group_layout_SFB);
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
blockscale_block_A.reset(total_elements_blockscale_A);
blockscale_block_B.reset(total_elements_blockscale_B);
}
/// Initialize operands to be used in the GEMM and reference GEMM
template <typename OptionType>
void initialize(const OptionType &options) {
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_after_alignment_host.data());
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementD *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
alpha_host.clear();
beta_host.clear();
for (int i = 0; i < options.groups; i++) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
ptr_blockscale_A.reset(options.groups);
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
ptr_blockscale_B.reset(options.groups);
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
layout_SFA.reset(options.groups);
layout_SFA.copy_from_host(layout_SFA_host.data());
layout_SFB.reset(options.groups);
layout_SFB.copy_from_host(layout_SFB_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2022);
initialize_block(block_B, seed + 2023);
initialize_block(block_C, seed + 2024);
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
template<typename GemmArguments, typename OptionType>
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
{
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
int device_id = 0;
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename Gemm::GemmKernel>(device_id);
GemmArguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_after_alignment_host.data() : (decltype(options.problem_sizes_after_alignment_host.data())) nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
ptr_blockscale_A.get(), layout_SFA.get(),
ptr_blockscale_B.get(), layout_SFB.get()
},
{
{}, // epilogue.thread
ptr_C.get(), stride_C.get(),
ptr_D.get(), stride_D.get()
},
kernel_hw_info
};
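// Argument groups above: GEMM mode, grouped problem shapes, mainloop operands (A/B with their
// blockwise scale-factor pointers and layouts), epilogue operands (C/D), and hardware info.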
auto &fusion_args = arguments.epilogue.thread;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha and beta are provided via command-line arguments, the same scalar alpha/beta applies to all groups.
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// Otherwise, per-group alpha/beta pointer arrays are used, i.e., alpha/beta can differ between groups.
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler swizzles with a size of up to 8, rounded down to the nearest power of two (i.e., 1, 2, 4, or 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
template <typename OptionType>
bool verify(const OptionType &options) {
//
// Compute reference output
//
std::vector<ElementA> block_A_host(block_A.size());
std::vector<ElementB> block_B_host(block_B.size());
std::vector<ElementC> block_C_host(block_C.size());
std::vector<ElementD> block_D_host_kernel(block_D.size());
std::vector<ElementD> block_D_host_ref(block_D.size());
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
block_A.copy_to_host(block_A_host.data());
block_B.copy_to_host(block_B_host.data());
block_C.copy_to_host(block_C_host.data());
block_D.copy_to_host(block_D_host_kernel.data());
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
bool passed = true;
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
// Group scaling tensor shapes are based on `ScaleGranularityM`, the CTA tile shape (TileShape), and the GEMM problem shape
auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx);
auto gemm_problem_shape = cute::make_shape(m, n, k);
// Create tensor views for the host reference GEMM
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
cute::make_layout(
cute::make_shape(m, k, 1),
stride_A_host.at(group_idx)
)
);
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
cute::make_layout(
cute::make_shape(n, k, 1),
stride_B_host.at(group_idx)
)
);
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
cute::make_layout(
cute::make_shape(m, n, 1),
stride_C_host.at(group_idx)
)
);
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
cute::make_layout(
cute::make_shape(m, n, 1),
stride_D_host.at(group_idx)
)
);
auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
layout_SFA_host.at(group_idx));
auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
layout_SFB_host.at(group_idx));
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
ElementScalar,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = alpha_host.at(group_idx);
epilogue_params.beta = beta_host.at(group_idx);
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Check if output from CUTLASS kernel and reference kernel are equal or not
auto this_group_passed = std::equal(
// std::execution::par_unseq,
block_D_host_ref.data() + offset_D.at(group_idx),
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
block_D_host_kernel.data() + offset_D.at(group_idx)
);
passed &= this_group_passed;
#if 0
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
#endif
}
return passed;
}
/// Execute a given example GEMM computation
template <typename OptionType>
int run(OptionType &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
result.gbps = options.template gbps<ElementA,
ElementB,
ElementC,
ElementD,
ElementBlockScale,
TileShape,
ScaleMsPerTile,
ScaleNsPerTile>(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.3 Toolkit (or newer) to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
//
// Parse options
//
Options<ProblemShape> options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
std::cout << "Running tests with host problem shapes:" << std::endl;
run(options, true);
std::cout << "Running tests without host problem shapes:" << std::endl;
run(options, false);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,84 @@
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=0) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=512 --groups=50 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
cutlass_example_add_executable(
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
)
# MSVC will fail to compile this example with the following error:
# fatal error C1083: Cannot open source file: <Some Mojibake>: No such file or directory [...\examples\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.vcxproj]
# This is a known issue and we are working on a fix.
if (NOT MSVC)
cutlass_example_add_executable(
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
)
endif()


@ -0,0 +1,270 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Command line options parsing
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
template<typename _ProblemShape>
struct Options {
using ProblemShape = _ProblemShape;
bool help = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, groups = 10;
std::string benchmark_path;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_after_alignment_host;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
int const tma_alignment_bits = 128;
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
int const k_alignment = 128;
int const m_alignment = 128;
int const n_alignment = 128;
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
int swizzle;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
char raster_char = 'H';
cmd.get_cmd_line_argument("raster", raster_char);
if (raster_char == 'N' || raster_char == 'n') {
raster_order = RasterOrderOptions::AlongN;
}
else if (raster_char == 'M' || raster_char == 'm') {
raster_order = RasterOrderOptions::AlongM;
}
else if (raster_char == 'H' || raster_char == 'h') {
raster_order = RasterOrderOptions::Heuristic;
}
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
// Decide how to initialize the problems
if (!benchmark_path.empty()) {
if (!benchmark_problems()) {
problem_sizes_after_alignment_host.clear();
problem_sizes_host.clear();
return;
}
}
else {
randomize_problems(cmd);
}
}
void randomize_problems(cutlass::CommandLine &cmd) {
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes_after_alignment_host.reserve(groups);
problem_sizes_host.reserve(groups);
for (int i = groups; i > 0; i--) {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
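// If an extent was not supplied on the command line, draw a random aligned value; e.g., with
// alignment = 16 and m_alignment = 128, M is drawn from {128, 256, ..., 1024}.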
if (m < 1) {
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
}
if (n < 1) {
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
}
if (k < 1) {
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
}
problem_sizes_after_alignment_host.push_back({m, n, k});
problem_sizes_host.push_back({m, n, k});
}
}
/// Load a benchmark
bool benchmark_problems() {
std::ifstream file(benchmark_path);
if (!file.good()) {
return false;
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
cutlass::gemm::GemmCoord extent_after_alignment, extent;
std::vector<std::string> tokens;
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
extent.at(i) = x;
// round up
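// e.g., with the 16-element alignment used here (128 bits / 8-bit e4m3), an extent of 1000 is padded to 1008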
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent_after_alignment.at(i) = x;
}
problem_sizes_after_alignment_host.push_back({extent_after_alignment.m(), extent_after_alignment.n(), extent_after_alignment.k()});
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
groups = static_cast<int>(problem_sizes_after_alignment_host.size());
return true;
}
/// Calculate memory bandwidth statistics
template <class ElementA,
class ElementB,
class ElementC,
class ElementD,
class ElementBlockScale,
class TileShape,
int ScaleMsPerTile,
int ScaleNsPerTile>
auto gbps(double runtime_s) const {
double total_read_bytes = 0;
double total_write_bytes = 0;
// Calculate bytes read and written for each problem
for (int i = 0; i < groups; ++i) {
auto problem = problem_sizes_host.at(i);
auto M = cute::get<0>(problem);
auto N = cute::get<1>(problem);
auto K = cute::get<2>(problem);
if (M > 0) { // Only count active problems
// Matrix A: M*K elements read
total_read_bytes += M * K * sizeof(ElementA);
// Matrix B: K*N elements read
total_read_bytes += K * N * sizeof(ElementB);
// Matrix C: M*N elements read (for beta operation)
total_read_bytes += M * N * sizeof(ElementC);
// Block scales for A and B
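// zipped_divide factors the problem extent into ((tile extents), (tile counts)); get<1> picks out the
// per-dimension tile counts used to size the block-scale reads below.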
auto blockscale_shape = cute::shape(cute::get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
auto groupscale_m = blockscale_m * ScaleMsPerTile;
auto groupscale_n = blockscale_n * ScaleNsPerTile;
total_read_bytes += groupscale_m * blockscale_k * sizeof(ElementBlockScale); // A scales
total_read_bytes += groupscale_n * blockscale_k * sizeof(ElementBlockScale); // B scales
// Matrix D: M*N elements written
total_write_bytes += M * N * sizeof(ElementD);
}
}
return (total_read_bytes + total_write_bytes) / 1.0e9 / runtime_s;
}
double bandwidth_util(double eff_bandwidth) const {
int memoryClockRate;
int memoryBusWidth;
cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, 0);
cudaDeviceGetAttribute(&memoryBusWidth, cudaDevAttrGlobalMemoryBusWidth , 0);
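// memoryClockRate is reported in kHz and memoryBusWidth in bits; the factor of 2 accounts for DDR and
// the 1.0e6 divisor converts kHz * bytes into GB/s, yielding the theoretical peak bandwidth.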
double bw = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6;
return eff_bandwidth / bw * 100.0;
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n"
<< " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
<< " --benchmark=<str> Executes a benchmark problem size.\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Number of real-valued multiply-adds
uint64_t fmas = 0ull;
for (auto const [m, n, k] : problem_sizes_host) {
fmas += static_cast<uint64_t>(m) *
static_cast<uint64_t>(n) *
static_cast<uint64_t>(k);
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
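// e.g., a single 1024x512x1024 problem contributes 2 * 1024 * 512 * 1024 ~= 1.07e9 flop, which at an
// average runtime of 0.5 ms corresponds to roughly 2.1 TFLOP/s.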
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};


@ -374,7 +374,7 @@ void allocate(Options const& options) {
auto N = get<1>(problem);
auto K = get<2>(problem);
const int scale_k = 1;
int const scale_k = cutlass::ceil_div(options.k, options.c);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
@ -510,7 +510,7 @@ void initialize(Options &options) {
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
initialize_quant_tensor(block_B, seed + 2022);
initialize_tensor(block_B, seed + 2022);
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
@ -519,13 +519,13 @@ void initialize(Options &options) {
for (int32_t i = 0; i < options.groups; ++i) {
const int scale_k = 1;
int const scale_k = cutlass::ceil_div(options.k, options.c);
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{});
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream);
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.c, stream);
}
problem_sizes.reset(options.groups);
@ -619,7 +619,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
arguments = Args {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
@ -676,6 +676,7 @@ bool verify(Options const& options) {
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
// we don't swap and transpose in the verify so revert the problem shape.
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
@ -712,7 +713,7 @@ bool verify(Options const& options) {
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
}
}
return passed;


@ -341,7 +341,7 @@ void allocate(Options const& options) {
auto N = get<1>(problem);
auto K = get<2>(problem);
const int scale_k = 1;
int const scale_k = cutlass::ceil_div(options.k, options.c);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
@ -479,7 +479,7 @@ void initialize(Options& options) {
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
initialize_quant_tensor(block_B, seed + 2022);
initialize_tensor(block_B, seed + 2022);
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
@ -565,7 +565,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
arguments = Args {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k},
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.c},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
@ -617,6 +617,7 @@ bool verify(Options const& options) {
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
// we don't swap and transpose in the verify so revert the problem shape.
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
@ -630,11 +631,11 @@ bool verify(Options const& options) {
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
const int scale_k = 1;
int const scale_k = cutlass::ceil_div(options.k, options.c);
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
//
// Compute reference output
@ -659,7 +660,7 @@ bool verify(Options const& options) {
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
}
}
return passed;


@ -282,7 +282,7 @@ void allocate(Options const& options) {
auto N = get<1>(problem);
auto K = get<2>(problem);
const int scale_k = 1;
int const scale_k = cutlass::ceil_div(options.k, options.c);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
@ -418,7 +418,7 @@ void initialize(Options &options) {
beta_device.copy_from_host(ptr_beta_host.data());
initialize_tensor(block_A, seed + 2023);
initialize_quant_tensor(block_B, seed + 2022);
initialize_tensor(block_B, seed + 2022);
initialize_tensor(block_C, seed + 2021);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
@ -485,7 +485,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k},
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
@ -542,6 +542,7 @@ bool verify(Options const& options) {
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
// we don't swap and transpose in the verify so revert the problem shape.
auto N = get<0>(problem);
auto M = get<1>(problem);
auto K = get<2>(problem);
@ -555,11 +556,11 @@ bool verify(Options const& options) {
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
const int scale_k = 1;
int const scale_k = cutlass::ceil_div(options.k, options.c);
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream);
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
//
// Compute reference output
@ -584,7 +585,7 @@ bool verify(Options const& options) {
CUDA_CHECK(cudaDeviceSynchronize());
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
std::cout << "Group: " << i << " Status: " << passed << std::endl;
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
}
}
return passed;


@ -50,6 +50,7 @@ set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10)
set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling
set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0) # Group-wise scaling
cutlass_example_add_executable(
69_hopper_mixed_dtype_grouped_gemm
@ -69,6 +70,7 @@ cutlass_example_add_executable(
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_GROUP
)
cutlass_example_add_executable(
@ -89,6 +91,7 @@ cutlass_example_add_executable(
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_GROUP
)
cutlass_example_add_executable(
@ -109,4 +112,5 @@ cutlass_example_add_executable(
TEST_RANDOM_PERF_LARGE_GROUP
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_GROUP
)


@ -7,8 +7,40 @@ This example shows how to perform Grouped GEMMs on Hopper when A and B have diff
- in the arguments, pass the group size, array of the problem sizes, and the array of strides for matrix A and B.
- if scales and zero-points are included, also pass the array of their strides in the arguments.
Note that in Example 55, the argument `--g` is used to determine the block scale size. It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs.
Note that in Example 55, the argument `--g` is used to determine the group size of scaling. To avoid confusion with the `--groups` argument in this example, which defines the number of GEMMs, `--c` is used here to represent the group size for scaling.
## Upcoming features
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed.
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling, as well as group-wise scaling when the problem shape is identical across all groups. Please contact us if zero-points or block-wise scaling are needed.
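
As a concrete illustration (values borrowed from the `TEST_SCALE_GROUP` configuration added in the CMake changes above; not additional documented behavior):

```cpp
// With --n=5120 --k=8192 --c=512, each group's scale tensor for B covers
// ceil(8192 / 512) = 16 chunks along K, i.e. 16 scale factors per column of B.
// Setting --c equal to --k degenerates to per-column scaling (one chunk).
constexpr int K = 8192, C = 512;
constexpr int ScaleChunks = (K + C - 1) / C;
static_assert(ScaleChunks == 16, "group-wise scaling: 16 scale factors per column");
```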
## Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -58,6 +58,7 @@ public:
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
cmd.get_cmd_line_argument("c", c);
MixedDtypeOptions::parse(argc, args);
@ -71,6 +72,7 @@ public:
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --c=<int> Sets the chunk size for scaling the quantized weights\n"
<< " --groups=<int> Sets the number of individual GEMM problems\n"
<< " --mode=<int> The mode to run the gemm\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
@ -183,11 +185,6 @@ void grouped_mixed_dtype_profiling(
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Sizes, Alpha, Beta\n";
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n';
}
std::cout << " Groups : " << options.groups << '\n'
<< " Avg runtime : " << result.avg_runtime_ms << " ms\n"
<< " GFLOPS : " << result.gflops << '\n';

View File

@ -194,12 +194,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle;
Options():
help(false),
m(8192), n(8192), k(8192),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -217,6 +219,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -231,6 +234,7 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
@ -331,6 +335,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
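
The same three-line pattern recurs in the Blackwell examples below: parse `--swizzle` and forward it to the tile scheduler as `max_swizzle_size`. A self-contained, stand-in sketch of that plumbing (the structs here are mocks, not CUTLASS types):

```cpp
#include <cstdlib>
#include <cstring>
#include <iostream>

struct SchedulerArgs { int max_swizzle_size = 0; };
struct Arguments    { SchedulerArgs scheduler; };

// Illustrative parser: extracts the value of a "--swizzle=<int>" argument.
int parse_swizzle(int argc, char const** argv) {
  for (int i = 1; i < argc; ++i) {
    if (std::strncmp(argv[i], "--swizzle=", 10) == 0) {
      return std::atoi(argv[i] + 10);
    }
  }
  return 0;  // default: no cluster rasterization swizzle
}

int main(int argc, char const** argv) {
  Arguments arguments;
  arguments.scheduler.max_swizzle_size = parse_swizzle(argc, argv);
  std::cout << "max_swizzle_size = " << arguments.scheduler.max_swizzle_size << "\n";
  return 0;
}
```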

View File

@ -231,6 +231,7 @@ struct Options {
bool save_amax = true;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
int swizzle = 0;
// Parses the command line
void parse(int argc, char const **args) {
@ -256,6 +257,7 @@ struct Options {
cmd.get_cmd_line_argument("save_aux", save_aux, true);
cmd.get_cmd_line_argument("save_amax", save_amax, true);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -271,6 +273,7 @@ struct Options {
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --scale_a=<f32> Scaling factor for A\n"
<< " --scale_b=<f32> Scaling factor for B\n"
<< " --scale_c=<f32> Scaling factor for C\n"
@ -476,6 +479,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
fusion_args.amax_D_ptr = abs_max_D.device_data();
}
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}

View File

@ -28,14 +28,29 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
set(TEST_SWIZZLE_1 --swizzle=1)
set(TEST_SWIZZLE_2 --swizzle=2)
set(TEST_SWIZZLE_5 --swizzle=5)
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
cutlass_example_add_executable(
70_blackwell_fp16_gemm
70_blackwell_fp16_gemm.cu
)
TEST_COMMAND_OPTIONS
TEST_SWIZZLE_1
TEST_SWIZZLE_2
TEST_SWIZZLE_5
TEST_SWIZZLE_5_UNEVEN
)
cutlass_example_add_executable(
70_blackwell_fp8_gemm
70_blackwell_fp8_gemm.cu
TEST_COMMAND_OPTIONS
TEST_SWIZZLE_1
TEST_SWIZZLE_2
TEST_SWIZZLE_5
TEST_SWIZZLE_5_UNEVEN
)
endif()

View File

@ -74,12 +74,14 @@ struct Options {
int m, n, k, l;
float alpha, beta;
int swizzle;
Options():
help(false),
error(false),
m(2048), n(2048), k(2048), l(1),
alpha(1.f), beta(0.f)
alpha(1.f), beta(0.f),
swizzle(0)
{ }
// Parses the command line
@ -97,6 +99,7 @@ struct Options {
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -112,7 +115,8 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n";
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n\n";
return out;
}
@ -352,6 +356,8 @@ struct ExampleRunner {
hw_info
};
arguments.scheduler.max_swizzle_size = options.swizzle;
// See example 48 for details on custom EVT construction
if constexpr (UseCustomEVT) {
arguments.epilogue.thread =

View File

@ -211,12 +211,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle = 0;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -234,6 +236,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -247,7 +250,8 @@ struct Options {
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
@ -333,7 +337,7 @@ bool initialize_block(
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
@ -344,8 +348,8 @@ void initialize(const Options &options) {
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
@ -387,6 +391,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
}
};
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}

View File

@ -177,7 +177,7 @@ using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using FusionOp = typename Gemm::EpilogueOutputOp;
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
using SfdOutputCfg = cutlass::detail::Sm100BlockScaledOutputConfig<OutputSFVectorSize>;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
@ -240,12 +240,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle = 0;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -263,6 +265,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -276,7 +279,8 @@ struct Options {
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
@ -362,9 +366,9 @@ bool initialize_block(
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// For SFD tensor layout
using Sm100BlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
@ -375,8 +379,8 @@ void initialize(const Options &options) {
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
@ -432,6 +436,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
}
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
@ -475,7 +480,12 @@ bool verify(const Options &options) {
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
block_SFD.sync_host();
bool passed_sfd = cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view());
passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0);
passed_sfd &= (cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0);
return passed && passed_sfd;
}
/// Execute a given example GEMM computation
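
The scale-factor output check added above mirrors the existing D check: element-wise equality plus non-zero norms on both sides, so an all-zero result cannot pass trivially. A self-contained sketch of that pattern, with stand-ins replacing `reference::host::TensorEquals` / `TensorNorm`:

```cpp
#include <cmath>
#include <vector>

// Stand-in for cutlass::reference::host::TensorEquals.
bool tensor_equals(std::vector<float> const& a, std::vector<float> const& b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i) if (a[i] != b[i]) return false;
  return true;
}

// Stand-in for cutlass::reference::host::TensorNorm (L2 norm).
double tensor_norm(std::vector<float> const& a) {
  double s = 0;
  for (float x : a) s += double(x) * double(x);
  return std::sqrt(s);
}

bool verify_d_and_sfd(std::vector<float> const& ref_D,   std::vector<float> const& D,
                      std::vector<float> const& ref_SFD, std::vector<float> const& SFD) {
  bool passed     = tensor_equals(ref_D, D)     && tensor_norm(ref_D) > 0   && tensor_norm(D) > 0;
  bool passed_sfd = tensor_equals(ref_SFD, SFD) && tensor_norm(ref_SFD) > 0 && tensor_norm(SFD) > 0;
  return passed && passed_sfd;
}
```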

View File

@ -212,12 +212,14 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int swizzle = 0;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
iterations(10),
swizzle(0)
{ }
// Parses the command line
@ -235,6 +237,7 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
@ -248,7 +251,8 @@ struct Options {
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
@ -334,7 +338,7 @@ bool initialize_block(
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
@ -345,8 +349,8 @@ void initialize(const Options &options) {
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
@ -388,6 +392,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
}
};
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}

View File

@ -214,7 +214,8 @@ struct Options {
int iterations;
int m, n, k;
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
int swizzle = 0;
Options():
help(false),
m(4096), n(4096), k(4096),
@ -223,7 +224,8 @@ struct Options {
preferred_cluster_m(4),
preferred_cluster_n(4),
fallback_cluster_m(2),
fallback_cluster_n(1)
fallback_cluster_n(1),
swizzle(0)
{ }
// Parses the command line
@ -245,6 +247,7 @@ struct Options {
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
cmd.get_cmd_line_argument("swizzle", swizzle);
if (!validate_cluster_shape()){
std::cout << "--Invalid cluster shapes" << std::endl;
@ -265,6 +268,7 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
@ -384,7 +388,8 @@ typename Gemm::Arguments args_from_options(const Options &options) {
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
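
For reference, a minimal sketch of how the preferred and fallback cluster shapes from the command line end up in the hardware info. Only `dim3` is a real CUDA type here; the other structs are stand-ins, and the fallback shape is what the runtime may use when the preferred cluster cannot be scheduled:

```cpp
#include <cuda_runtime.h>   // for dim3

struct ClusterOptions {
  int preferred_cluster_m = 4, preferred_cluster_n = 4;
  int fallback_cluster_m  = 2, fallback_cluster_n  = 1;
};

struct HwInfo {
  dim3 cluster_shape;
  dim3 cluster_shape_fallback;
};

HwInfo make_hw_info(ClusterOptions const& options) {
  HwInfo hw_info;
  hw_info.cluster_shape          = dim3(options.preferred_cluster_m, options.preferred_cluster_n, 1);
  hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m,  options.fallback_cluster_n,  1);
  return hw_info;
}
```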

View File

@ -543,7 +543,7 @@ int run(Options &options) {
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// CUTLASS must be compiled with CUDA 12.8 Toolkit or newer to run this example
// and must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -560,7 +560,6 @@ int main(int argc, char const **args) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
//
// Parse options
//

View File

@ -237,11 +237,12 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {
bool help = false;
bool use_pdl = false;
float alpha = FLT_MAX;
float beta = FLT_MAX;
@ -264,6 +265,9 @@ struct Options {
help = true;
return;
}
if (cmd.check_cmd_line_flag("use_pdl")) {
use_pdl = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
@ -350,19 +354,9 @@ struct Options {
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
extent.at(i) = std::atoi(tokens.at(i).c_str());
}
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
groups = static_cast<int>(problem_sizes_host.size());
@ -387,7 +381,8 @@ struct Options {
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --benchmark=<str> Executes a benchmark problem size\n"
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n";
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
out
<< "\n\nExamples:\n\n"
@ -711,7 +706,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
@ -730,7 +725,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
}
timer.stop();
@ -740,7 +735,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
}
return 0;
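
A stand-in sketch of the launch change above: `--use_pdl` simply flips the third argument of `run()`, which in these examples selects Programmatic Dependent Launch. The `MockGemm` type is illustrative, not the CUTLASS device adapter:

```cpp
#include <iostream>

struct MockGemm {
  int run(void* stream, void* cuda_adapter, bool launch_with_pdl) {
    // Stands in for GemmUniversalAdapter::run(); returns 0 for "success".
    std::cout << "launching" << (launch_with_pdl ? " with PDL" : "") << "\n";
    return 0;
  }
};

int main() {
  bool use_pdl = true;   // set from the --use_pdl flag in the example
  MockGemm gemm;
  return gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ use_pdl);
}
```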

View File

@ -124,6 +124,7 @@ constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // A
using ElementAccumulator = float; // Element type for internal accumulation
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands
constexpr int OutputSFVectorSize = 16;
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
@ -219,14 +220,14 @@ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig<
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig<
OutputSFVectorSize,
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
>;
using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom;
using LayoutSFD = typename Sm100BlockScaledOutputConfig::LayoutSF;
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
// Host-side allocations
std::vector<StrideA> stride_A_host;
@ -299,12 +300,13 @@ auto make_iterator(T* ptr) {
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {
bool help = false;
bool verification = true;
bool use_pdl = false;
float alpha = FLT_MAX;
float beta = FLT_MAX;
@ -328,9 +330,12 @@ struct Options {
help = true;
return;
}
if (cmd.check_cmd_line_flag("no-verif")) {
if (cmd.check_cmd_line_flag("no_verif")) {
verification = false;
}
if (cmd.check_cmd_line_flag("use_pdl")) {
use_pdl = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
@ -418,19 +423,9 @@ struct Options {
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
extent.at(i) = std::atoi(tokens.at(i).c_str());
}
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
groups = static_cast<int>(problem_sizes_host.size());
@ -457,7 +452,8 @@ struct Options {
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --benchmark=<str> Executes a benchmark problem size\n"
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
<< " --no-verif Do not run (host-side) verification kernels\n";
<< " --no_verif Do not run (host-side) verification kernels\n"
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
out
<< "\n\nExamples:\n\n"
@ -554,9 +550,9 @@ void allocate(const Options &options) {
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
stride_A_host.push_back(stride_A);
stride_B_host.push_back(stride_B);
@ -775,9 +771,9 @@ bool verify(const Options &options) {
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A);
@ -845,7 +841,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
cudaDeviceSynchronize();
@ -870,7 +866,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
}
timer.stop();
@ -880,7 +876,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
}
return 0;

View File

@ -490,7 +490,7 @@ int run(Options &options)
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -490,7 +490,7 @@ int run(Options &options)
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -499,11 +499,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
//
// Parse options
//

View File

@ -67,9 +67,6 @@
--b=2048 --h=2048 --d=2048 --q=2048 --k=2048
*/
#define DSHOW(x) print(#x ": "); print(x); print("\n");
#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n");
#include <iostream>
#include <random>
#include <regex>
@ -119,16 +116,20 @@ struct Options {
int h_k = 1;
int q = 256;
int k = 256;
std::vector<int> varlen_q;
std::vector<int> varlen_k;
int d = 128;
int warmup_iterations = 1;
int iterations = 3;
int tensor_ring_buffers = 1;
bool verify = false;
bool verbose = false;
bool causal = false;
bool residual = false;
bool varlen = false;
bool persistent = false;
int sm_count = 0;
std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom;
@ -182,20 +183,87 @@ struct Options {
cmd.get_cmd_line_argument("h_k", h_k, -1);
if (h_k == -1) h_k = h;
varlen = cmd.check_cmd_line_flag("varlen");
cmd.get_cmd_line_argument("q", q, -1);
cmd.get_cmd_line_argument("k", k, -1);
cmd.get_cmd_line_argument("b", b, -1);
std::string varlen_q_str;
cmd.get_cmd_line_argument("varlen-q", varlen_q_str);
std::string varlen_k_str;
cmd.get_cmd_line_argument("varlen-k", varlen_k_str);
if (varlen && ! varlen_q_str.empty()) {
varlen_q.clear();
while (! varlen_q_str.empty()) {
size_t pos = varlen_q_str.find(':');
varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_q_str = varlen_q_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_q.size());
}
if (b != static_cast<int>(varlen_q.size())) {
std::cout << "Error: Invalid --varlen-q length\n";
std::exit(-1);
}
int new_q = 0;
for (auto elem : varlen_q) {
new_q += elem;
}
if (q != -1) {
std::cout << "Error: Can't provide --q and --varlen-q\n";
std::exit(-1);
}
q = new_q;
}
if (varlen && ! varlen_k_str.empty()) {
varlen_k.clear();
while (! varlen_k_str.empty()) {
size_t pos = varlen_k_str.find(':');
varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_k_str = varlen_k_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_k.size());
}
if (b != static_cast<int>(varlen_k.size())) {
std::cout << " Error: Invalid --varlen-k length\n";
std::exit(-1);
}
int new_k = 0;
for (auto elem : varlen_k) {
new_k += elem;
}
if (k != -1) {
std::cout << "Error: Can't provide --k and --varlen-k\n";
std::exit(-1);
}
k = new_k;
}
if (q == -1) q = k;
if (k == -1) k = q;
if (q == -1 && k == -1) q = k = defaults.q;
cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "no" || mask == "") {
@ -213,7 +281,6 @@ struct Options {
causal = false;
}
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
@ -237,18 +304,23 @@ struct Options {
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn"
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
<< " --d=<int> Sets the D extent\n"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
<< " with the last batch sized to make it fit\n"
<< " implies at least residual masking for correctness\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
return out;
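
A self-contained sketch of the colon-separated parsing that backs `--varlen-q`/`--varlen-k` (the helper name is illustrative; the loop mirrors `parse()` above):

```cpp
#include <cassert>
#include <string>
#include <vector>

// "128:256:64" -> {128, 256, 64}
std::vector<int> parse_varlen(std::string s) {
  std::vector<int> lengths;
  while (!s.empty()) {
    size_t pos = s.find(':');
    lengths.push_back(std::stoi(s.substr(0, pos)));
    if (pos == std::string::npos) break;
    s = s.substr(pos + 1);
  }
  return lengths;
}

int main() {
  auto q = parse_varlen("128:256:64");
  assert(q.size() == 3 && q[0] == 128 && q[1] == 256 && q[2] == 64);
  // In the example, the batch count b must equal q.size(), and --q may not be
  // combined with --varlen-q.
  return 0;
}
```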
@ -382,40 +454,55 @@ struct FwdRunner {
StrideLSE stride_LSE;
uint64_t seed = 0;
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
struct DeviceBuffer {
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
DeviceBuffer() = default;
DeviceBuffer(const DeviceBuffer&) = delete;
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
size_t get_storage_size() const {
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
+ device_cumulative_seqlen_kv.get_storage_size();
}
};
std::vector<std::unique_ptr<DeviceBuffer>> buffers;
std::vector<int> cumulative_seqlen_q;
std::vector<int> cumulative_seqlen_kv;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
//
// Methods
//
bool verify(const ProblemShapeType& problem_shape) {
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
select<1,2,3>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
select<1,2,3>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
select<0,2,3>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
@ -434,7 +521,7 @@ struct FwdRunner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
@ -442,20 +529,22 @@ struct FwdRunner {
<< " mean " << mean_diff << std::endl;
}
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
bool passed_LSE = true; // future work
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
// if ( ! passed_LSE) {
// std::cerr << "failed LSE: max diff " << max_diff
// << " mean " << mean_diff << std::endl;
// }
bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
return passed_O && passed_LSE;
}
template<class ProblemShape>
auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) {
auto initialize_varlen(
const Options& options, const ProblemShape& problem_size,
const bool kVarlenSame = true) {
int num_batches = get<3,1>(problem_size);
// generate Q as --b times
@ -483,8 +572,12 @@ struct FwdRunner {
int max_seqlen_kv = 0;
for (int i = 0; i < num_batches; i++) {
int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng);
int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng);
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
kVarlenSame ? get<0>(problem_size) :
generate_positive_int(dist_q, rng);
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
kVarlenSame ? get<1>(problem_size) :
generate_positive_int(dist_kv, rng);
total_seqlen_q += seqlen_q;
total_seqlen_kv += seqlen_kv;
@ -525,7 +618,7 @@ struct FwdRunner {
decltype(problem_shape_in) problem_size;
if constexpr (kIsVarlen) {
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(options, problem_shape_in);
problem_shape = problem_shape_launch;
problem_size = problem_shape_init;
}
@ -562,50 +655,72 @@ struct FwdRunner {
get<1,1>(stride_LSE) = 0;
}
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_LSE.reset(size(shape_LSE));
block_ref_O.reset(size(shape_QO));
block_ref_LSE.reset(size(shape_LSE));
auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_ref_LSE.reset(size(shape_LSE));
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v);
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
if ( ! cumulative_seqlen_q.empty()) {
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
if ( ! cumulative_seqlen_q.empty()) {
buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
buffer.device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
}
if ( ! cumulative_seqlen_kv.empty()) {
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
buffer.device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
};
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
int tensor_ring_buffers = options.tensor_ring_buffers;
for (int i = 1; i < tensor_ring_buffers; i++) {
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
}
if constexpr (kIsVarlen) {
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
}
return problem_shape;
}
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
}
typename Operation::Arguments arguments{
problem_shape_,
{ buffers[buffer_index]->block_Q.get(), stride_Q,
buffers[buffer_index]->block_K.get(), stride_K,
buffers[buffer_index]->block_V.get(), stride_V },
{ buffers[buffer_index]->block_O.get(), stride_O,
buffers[buffer_index]->block_LSE.get(), stride_LSE },
hw_info
};
return arguments;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_shape = initialize(options);
typename Operation::Arguments arguments{
problem_shape,
{ block_Q.get(), stride_Q,
block_K.get(), stride_K,
block_V.get(), stride_V },
{ block_O.get(), stride_O,
block_LSE.get(), stride_LSE },
hw_info
};
int buffer_index = 0;
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
Operation op;
@ -633,11 +748,21 @@ struct FwdRunner {
}
// Run
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
for (int i = 0; i < options.warmup_iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
}
cudaError_t result = cudaDeviceSynchronize();
@ -675,6 +800,14 @@ struct FwdRunner {
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
}
//
@ -737,10 +870,10 @@ struct FwdRunner {
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape);
passed = verify(problem_shape, *buffers[0]);
if (passed) example_result.verified = true;
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;
@ -755,11 +888,18 @@ struct FwdRunner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
if (! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
@ -792,10 +932,14 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
using HeadDim = _128;
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -821,10 +965,14 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _64;
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
}
@ -848,10 +996,14 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _32;
#ifdef FP8
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
if (options.persistent) {
// Persistent Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
}
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
#endif
}
@ -948,7 +1100,7 @@ int main_single(int argc, char const **args) {
});
#endif
return 0;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -956,8 +1108,6 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -984,7 +1134,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return result;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,865 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3.
This example showcases the use of CUTLASS to build backward fused
multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting
the NVIDIA Blackwell architecture.
Background and motivation
-------------------------
CUTLASS is a highly flexible library that provides open-source building blocks
for tensor core programming for GEMM or GEMM-like problems. Fused multi-head
attention (FMHA) is a foundational kernel for large language models (LLMs) since it
makes long sequence lengths feasible from a memory-usage perspective. It also
improves computational efficiency since it transforms an outer-product-like and
a matrix-vector-like GEMM into a fused operation with much higher arithmetic
intensity. For more details, see Dao et al, 2022; Dao, 2023.
Implementing this kernel in CUTLASS enabled easy customization and high
performance.
Introduction
------------
The example targets the NVIDIA Blackwell architecture, and takes advantage of
5th gen tensor cores and the Tensor Memory Accelerator (TMA), just like
GEMMs do. It provides a backward pass (often abbreviated
bwd in the code).
The code is structured into three layers: The runner (and the reference kernels)
takes care of initialization, measurement, and testing; the device layer
orchestrates kernel calls and partitions workspace; and the kernel layer implements
the core computation (just like the CUTLASS kernel layer).
Support
---------
We support fp16 and fp8 data types with a head dimension of 128.
Example usage:
$ ./examples/77_blackwell_fmha/77_blackwell_fmha_bwd_fp16 \
--b=2048 --h=2048 --d=2048 --q=2048 --k=2048
*/
#include <iostream>
#include <random>
#include <regex>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "reference/fmha_fwd_reference.hpp"
#include "reference/fmha_bwd_reference.hpp"
#include "reference/reference_abs_error.hpp"
#include "collective/fmha_fusion.hpp"
#include "device/fmha_device_bwd.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
using namespace cute;
using namespace cutlass::fmha::kernel;
using namespace cutlass::fmha::collective;
using namespace cutlass::fmha;
///////////////////////////////////////////////////////////////////////////////////////////////////
enum class InitStyle {
kOne, kZero, kLinearStride128, kLinearStride1, kRandom, kNone
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help = false;
bool error = false;
int b = 16;
int h = 16;
int h_k = 1;
int q = 1024;
int k = 1024;
int d = 128;
int iterations = 3;
bool verify = false;
bool verbose = false;
bool causal = false;
int sm_count = 0;
std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom;
InitStyle init_style_k = InitStyle::kRandom;
InitStyle init_style_v = InitStyle::kRandom;
InitStyle init_style_do = InitStyle::kRandom;
bool skip_reference = false;
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
std::string s;
cmd.get_cmd_line_argument(name, s, s);
if (s.empty()) {
dst = src;
}
else {
if (s == "r") {
dst = InitStyle::kRandom;
}
else if (s == "0") {
dst = InitStyle::kZero;
}
else if (s == "1") {
dst = InitStyle::kOne;
}
else if (s == "d") {
dst = InitStyle::kLinearStride1;
}
else if (s == "s") {
dst = InitStyle::kLinearStride128;
}
else if (s == "n") {
dst = InitStyle::kNone;
}
else {
std::cout << "Error: " << s << " is not a valid input type.\n";
std::exit(-1);
}
}
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
Options defaults;
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("d", d, defaults.d);
cmd.get_cmd_line_argument("h", h, -1);
if (h == -1) h = 2048 / d;
cmd.get_cmd_line_argument("q", q, -1);
cmd.get_cmd_line_argument("k", k, -1);
if (q == -1) q = k;
if (k == -1) k = q;
if (q == -1 && k == -1) q = k = defaults.q;
cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "causal") {
causal = true;
}
else {
causal = defaults.causal;
}
skip_reference = cmd.check_cmd_line_flag("skip-reference");
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_k);
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_v);
get_init_style_argument(cmd, "init-style", init_style_do, defaults.init_style_do);
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k);
get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v);
get_init_style_argument(cmd, "init-style-do", init_style_do, init_style_do);
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "77_blackwell_fmha_bwd\n\n"
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
<< " fused multi-head attention kernels for the backward pass targeting NVIDIA's Blackwell architecture.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the B extent\n"
<< " --h=<int> Sets the H extent\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|causal> Enables masking\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
void initialize_block(
DeviceAllocation<Element>& block,
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
switch (init_style) {
case InitStyle::kOne: {
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, (Element) 1, (Element) 1);
break;
}
case InitStyle::kZero: {
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, (Element) 0, (Element) 0);
break;
}
case InitStyle::kRandom: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) 0, (Element) 1);
break;
}
case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {
for (int j = 0; j < 128; j++) {
data[j + 128*i] = static_cast<Element>((double) (j % 4));
}
}
block.copy_from_host(data.data(), data.size());
break;
}
case InitStyle::kLinearStride128: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {
for (int j = 0; j < 128; j++) {
data[j + 128*i] = static_cast<Element>((double) (i % 4));
}
}
block.copy_from_host(data.data(), data.size());
break;
}
case InitStyle::kNone: {
break;
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct ExampleResult {
bool passed = false;
bool verified = false;
float runtime_ms = 0;
double tflops_tc_s = 0;
size_t smem_size = 0;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
template<
class TileShape,
class DispatchPolicy,
class ActiveMask,
class... KernelOptions
>
struct BwdRunner {
#ifdef FP8
using Element = cutlass::float_e4m3_t;
#else
using Element = cutlass::half_t;
#endif
using ElementAccumulator = float;
// Q K D (H B)
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
using StrideK = TensorStride;
using StrideV = TensorStride;
using StrideO = TensorStride;
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
// Backwards specific
using StrideDQ = TensorStride;
using StrideDK = TensorStride;
using StrideDV = TensorStride;
using StrideDO = TensorStride;
//
// Data members
//
/// Initialization
StrideQ stride_Q;
StrideK stride_K;
StrideV stride_V;
StrideO stride_O;
StrideLSE stride_LSE;
StrideDQ stride_dQ;
StrideDK stride_dK;
StrideDV stride_dV;
StrideDO stride_dO;
uint64_t seed = 0;
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_K;
DeviceAllocation<Element> block_V;
DeviceAllocation<Element> block_O;
DeviceAllocation<ElementAccumulator> block_LSE;
DeviceAllocation<Element> block_dQ;
DeviceAllocation<Element> block_dK;
DeviceAllocation<Element> block_dV;
DeviceAllocation<Element> block_dO;
DeviceAllocation<Element> block_ref_dQ;
DeviceAllocation<Element> block_ref_dK;
DeviceAllocation<Element> block_ref_dV;
//
// Methods
//
bool verify(const ProblemShapeType& problem_shape) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,2,3>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
select<0,2,3>(problem_shape),
stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
select<1,2,3>(problem_shape),
stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
select<1,2,3>(problem_shape),
stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
select<0,2,3>(problem_shape),
stride_dO);
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-0 : 1e-2;
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(block_dQ, block_ref_dQ, max_diff, mean_diff);
bool passed_dQ = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_dQ) {
std::cerr << "failed dQ: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
reference_abs_diff(block_dK, block_ref_dK, max_diff, mean_diff);
bool passed_dK = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_dK) {
std::cerr << "failed dK: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
reference_abs_diff(block_dV, block_ref_dV, max_diff, mean_diff);
bool passed_dV = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_dV) {
std::cerr << "failed dV: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
return passed_dQ && passed_dK && passed_dV;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
auto [Q, K, D, HB] = problem_shape;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto shape_QO = select<0,2,3>(problem_shape);
auto shape_KV = select<1,2,3>(problem_shape);
auto shape_LSE = select<0,3>(problem_shape);
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
stride_V = stride_K;
stride_O = stride_Q;
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
stride_dQ = stride_Q;
stride_dK = stride_K;
stride_dV = stride_V;
stride_dO = stride_O;
auto lsize = [](auto shape) {
return size(make_shape(1ull, shape));
};
block_Q.reset(lsize(shape_QO));
block_K.reset(lsize(shape_KV));
block_V.reset(lsize(shape_KV));
block_O.reset(lsize(shape_QO));
block_LSE.reset(lsize(shape_LSE));
block_dQ.reset(lsize(shape_QO));
block_dK.reset(lsize(shape_KV));
block_dV.reset(lsize(shape_KV));
block_dO.reset(lsize(shape_QO));
block_ref_dQ.reset(lsize(shape_QO));
block_ref_dK.reset(lsize(shape_KV));
block_ref_dV.reset(lsize(shape_KV));
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v);
initialize_block(block_dO, seed + 2020, options.init_style_do);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,2,3>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
if (! options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
initialize(problem_shape, options);
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
typename Operation::Arguments arguments{
problem_shape,
block_Q.get(), stride_Q,
block_K.get(), stride_K,
block_V.get(), stride_V,
block_O.get(), stride_O,
block_LSE.get(), stride_LSE,
block_dO.get(), stride_dO,
block_dQ.get(), stride_dQ,
block_dK.get(), stride_dK,
block_dV.get(), stride_dV,
softmax_scale,
hw_info
};
Operation op;
ExampleResult example_result;
example_result.smem_size = Operation::Kernel::SharedStorageSize;
size_t workspace_size = 0;
workspace_size = Operation::get_workspace_size(arguments);
DeviceAllocation<uint8_t> workspace(workspace_size);
cutlass::Status status = cutlass::Status::kSuccess;
status = op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
status = op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
// Run
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return example_result;
}
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result = cudaEventCreate(&event);
if (result != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
}
// Record an event at the start of a series of GEMMs
result = cudaEventRecord(events[0]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
for (int i = 0; i < options.iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result = cudaEventRecord(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
// Wait for work on the device to complete.
result = cudaEventSynchronize(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
runtime_ms /= static_cast<float>(options.iterations);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= static_cast<double>(get<2>(problem_shape));
flops *= static_cast<double>(get<3,0>(problem_shape));
flops *= static_cast<double>(get<3,1>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
example_result.runtime_ms = runtime_ms;
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return example_result;
}
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape);
if (passed) example_result.verified = true;
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;
}
example_result.passed = true;
return example_result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
if (verbose) {
std::cout << " t=" << result.runtime_ms << "ms, "
"smem=" << result.smem_size << "b" << std::endl;
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct KernelCoop {};
//////////////////////////////////////////////////////////////////////////////////////////////////
template<class Mask>
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using HeadDim = _64;
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
}
///////////////////////////////////////////////////////////////////////////////////////////////////
template<class Mask>
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using HeadDim = _128;
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_single(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
std::cout
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
<< "(compute capability 100a) and CUDA 12.8 or greater.\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
if (options.sm_count == 0) {
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
else {
hw_info.sm_count = options.sm_count;
}
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;
auto with_causal = [&](auto fn) {
if (options.causal) {
fn(CausalMask{});
}
else {
fn(NoMask{});
}
};
with_causal([&](auto fusion) {
if (options.d <= 64) {
run_bwd_64(fusion, options, hw_info);
}
else if (options.d <= 128) {
run_bwd_128(fusion, options, hw_info);
}
else {
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
}
});
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
auto arg = full_arguments[i];
size_t eq_pos = arg.find('=');
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
for (;;) {
size_t comma_pos = rest.find(',');
std::string current = rest.substr(0, comma_pos);
full_arguments[i] = prefix + current;
std::vector<const char*> next_args;
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
main(argc, next_args.data());
if (comma_pos == std::string::npos) break;
rest = rest.substr(comma_pos+1);
}
recursed = true;
break;
}
}
if (! recursed) {
main_single(argc, args);
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -689,11 +689,18 @@ struct ExampleRunner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
if (result.supported && ! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
@ -781,12 +788,17 @@ int main_single(int argc, char const **args) {
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
)
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
if (options.d == 128) {
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
}
else {
std::cout << "Head Dimension != 128 is not supported for the fmha_gen example\n";
}
#endif
return 0;
@ -797,8 +809,6 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -825,7 +835,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return result;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,839 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file A MLA (Multi-Head Latent Attention) inference kernel sample for the
NVIDIA Blackwell Architecture.
*/
#include <iostream>
#include <random>
#include <regex>
#include <cmath>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "reference/fmha_mla_reference.hpp"
#include "reference/reference_abs_error.hpp"
#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"
///////////////////////////////////////////////////////////////////////////////////////////////////
using namespace cute;
using namespace cutlass::fmha::kernel;
///////////////////////////////////////////////////////////////////////////////////////////////////
enum class InitStyle {
kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help = false;
bool error = false;
int b = 1;
int k = 256;
int split_kv = -1; // number of split along k dim.
bool is_var_split_kv = false;
int max_split_kv = 16;
int page = -1;
float spread = 0.2f;
int iterations = 3;
bool verify = false;
bool verbose = false;
int sm_count = 0;
std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom;
InitStyle init_style_c = InitStyle::kRandom;
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
std::string s;
cmd.get_cmd_line_argument(name, s, s);
if (s.empty()) {
dst = src;
}
else {
if (s == "r") {
dst = InitStyle::kRandom;
}
else if (s == "l") {
dst = InitStyle::kRandomLarge;
}
else if (s == "1") {
dst = InitStyle::kOne;
}
else if (s == "d") {
dst = InitStyle::kLinearStride1;
}
else if (s == "s") {
dst = InitStyle::kLinearStride128;
}
else if (s == "n") {
dst = InitStyle::kNone;
}
else {
std::cout << "Error: " << s << " is not a valid input type.\n";
std::exit(-1);
}
}
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
Options defaults;
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("k", k, -1);
if (k == -1) k = defaults.k;
cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
cmd.get_cmd_line_argument("page", page, defaults.page);
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
if (page == -1) {
is_var_split_kv = false;
}
cmd.get_cmd_line_argument("max_split_kv", max_split_kv, defaults.max_split_kv);
if (is_var_split_kv == true) {
split_kv = max_split_kv;
}
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_c, defaults.init_style_c);
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
get_init_style_argument(cmd, "init-style-c", init_style_c, init_style_c);
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "77_blackwell_mla\n\n"
<< " This example showcases the use of CUTLASS for fused multi-head latent\n"
<< " attention kernels targeting NVIDIA's Blackwell architecture.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the B extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --page=<int> Enables paging and sets the page size\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --spread=<float> Relative spread away from K for paging\n"
<< " --split_kv=<int> Split KV factor\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
void initialize_block(
DeviceAllocation<Element>& block,
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
switch (init_style) {
case InitStyle::kOne: {
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, (Element) 1, (Element) 1);
break;
}
case InitStyle::kRandom: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 1);
break;
}
case InitStyle::kRandomLarge: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 100);
break;
}
case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) {
for (int j = 0; j < 128; j++) {
data[j + 128*i] = static_cast<Element>((double) (j % 4));
}
}
block.copy_from_host(data.data(), data.size());
break;
}
case InitStyle::kLinearStride128: {
std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 64; i ++) {
for (int j = 0; j < 64; j++) {
data[j + 64*i] = static_cast<Element>((double) (i % 9));
}
}
block.copy_from_host(data.data(), data.size());
break;
}
case InitStyle::kNone: {
break;
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct ExampleResult {
bool passed = false;
bool verified = false;
float runtime_ms = 0;
double tflops_tc_s = 0;
double tbytes_s = 0;
size_t smem_size = 0;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
template<bool v>
struct IsPersistent {
static const bool value = v;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template<
class TileShape,
class PersistenceOption = IsPersistent<true>
>
struct Runner {
#ifdef FP8
using Element = cutlass::float_e4m3_t;
#elif FP16
using Element = cutlass::half_t;
#else
#error "Must either define FP8 or FP16"
#endif
using ElementAcc = float;
using ElementOut = cutlass::half_t;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
using TileShapeD = cute::tuple_element_t<2, TileShape>;
// H K (D_latent D_rope) B
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
using StrideO = StrideK; // H D B
using StrideLSE = cute::tuple<_1, int>; // H B
using TileScheduler = std::conditional_t<
PersistenceOption::value,
Sm100MlaPersistentTileScheduler,
Sm100MlaIndividualTileScheduler
>;
using Kernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler
>;
using Operation = cutlass::fmha::device::MLA<Kernel>;
//
// Data members
//
/// Initialization
StrideQ stride_Q_latent;
StrideK stride_C_latent;
StrideQ stride_Q_rope;
StrideK stride_K_rope;
StrideO stride_O;
StrideLSE stride_LSE;
StrideLSE stride_PT;
uint64_t seed = 0;
int page_size = -1;
int page_count = -1;
// We allocate Q and C as first latent, then rope
// This means that we offset the pointer by HeadDim_latent to get the rope
// portion
DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_C;
DeviceAllocation<ElementOut> block_O;
DeviceAllocation<int> block_seq;
DeviceAllocation<int> block_PT;
DeviceAllocation<int> block_split_kv;
DeviceAllocation<int> block_accum_split_len;
DeviceAllocation<ElementAcc> block_LSE;
DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAcc> block_ref_LSE;
ElementAcc scale;
//
// Methods
//
bool verify(const ProblemShape& problem_shape) {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
int page_K = K;
int page_B = B;
if (block_PT.get() != nullptr) {
page_K = page_size;
page_B = page_count;
}
Tensor mQ_latent = make_tensor(make_gmem_ptr(block_Q.get()),
cute::make_tuple(H, D_latent, B),
stride_Q_latent);
Tensor mQ_rope = make_tensor(make_gmem_ptr(block_Q.get() + D_latent),
cute::make_tuple(H, D_rope, B),
stride_Q_rope);
Tensor mC_latent = make_tensor(make_gmem_ptr(block_C.get()),
cute::make_tuple(page_K, D_latent, page_B),
stride_C_latent);
Tensor mK_rope = make_tensor(make_gmem_ptr(block_C.get() + D_latent),
cute::make_tuple(page_K, D_rope, page_B),
stride_K_rope);
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
cute::make_tuple(H, D_latent, B),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
cute::make_tuple(H, B),
stride_LSE);
Tensor mSeq = make_tensor(make_gmem_ptr(static_cast<int*>(block_seq.get())), make_shape(B));
Tensor mPT = make_tensor(make_gmem_ptr(static_cast<int*>(block_PT.get())), make_shape(ceil_div(K, page_size), B), stride_PT);
fmha_mla_reference(problem_shape, mSeq, mPT, mQ_latent, mQ_rope, mC_latent, mK_rope, mO, mLSE, scale);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
std::cerr << "failed O: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
bool passed_LSE = true;
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if ( ! passed_LSE) {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
return passed_O && passed_LSE;
}
ProblemShape initialize(const Options& options) {
auto problem_shape = cute::make_tuple(TileShapeH{}, options.k, TileShapeD{}, options.b);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
// the scale is based on the non-absorbed sizes, change as appropriate
// we can't determine this parameter from the info we have, it's an input
int D_non_latent = 128;
scale = static_cast<decltype(scale)>(1.0 / sqrt(1.0 * (D_non_latent + D_rope)));
// Shape (H, D, B)
stride_Q_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
stride_Q_rope = stride_Q_latent;
stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
stride_LSE = cute::make_tuple(_1{}, 0 + H);
block_Q.reset(static_cast<size_t>(options.b) * H * (D_latent + D_rope));
block_O.reset(static_cast<size_t>(options.b) * H * D_latent);
block_LSE.reset(static_cast<size_t>(options.b) * H);
block_ref_O.reset(static_cast<size_t>(options.b) * H * D_latent);
block_ref_LSE.reset(static_cast<size_t>(options.b) * H);
if (options.page == -1) {
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(options.k) * (D_latent + D_rope));
stride_K_rope = stride_C_latent;
block_C.reset(static_cast<size_t>(options.b) * options.k * (D_latent + D_rope));
}
else {
float spread = options.spread;
int max_K = static_cast<int>((1 + spread) * K);
int min_K = static_cast<int>((1 - spread) * K);
page_size = options.page;
page_count = B * ceil_div(max_K, page_size);
stride_PT = cute::make_stride(_1{}, page_count);
std::vector<int> host_seq(B);
std::vector<int> host_PT(page_count * B);
for (int i = 0; i < B; i++) {
int seq = min_K + rand() % (max_K - min_K + 1);
host_seq[i] = seq;
for (int j = 0; j < ceil_div(seq, page_size); j++) {
host_PT[page_count * i + j] = i + j * B;
}
}
block_seq.reset(host_seq.size());
block_seq.copy_from_host(host_seq.data(), host_seq.size());
block_PT.reset(host_PT.size());
block_PT.copy_from_host(host_PT.data(), host_PT.size());
get<1>(problem_shape) = max_K;
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, page_size * static_cast<int64_t>((D_latent + D_rope)));
stride_K_rope = stride_C_latent;
block_C.reset(page_count * page_size * static_cast<int64_t>((D_latent + D_rope)));
if (options.is_var_split_kv == true) {
std::vector<int> host_split_kv(B);
for(int i = 0; i < B; ++i) {
auto len = host_seq[i];
int split = ceil_div(options.max_split_kv, ceil_div(max_K, len));
host_split_kv[i] = split;
}
block_split_kv.reset(B);
block_split_kv.copy_from_host(host_split_kv.data(), host_split_kv.size());
}
}
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_C, seed + 2022, options.init_style_c);
return problem_shape;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShape problem_shape = initialize(options);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
typename Operation::Arguments arguments{
problem_shape,
{ scale,
block_Q.get(), stride_Q_latent,
block_Q.get() + D_latent, stride_Q_rope,
block_C.get(), stride_C_latent,
block_C.get() + D_latent, stride_K_rope,
block_seq.get(),
block_PT.get(), stride_PT,
page_count, page_size},
{ block_O.get(),
stride_O,
block_LSE.get(),
stride_LSE},
hw_info,
options.split_kv,
options.is_var_split_kv ? block_split_kv.get() : nullptr
};
if (options.split_kv < 0 && !options.is_var_split_kv) {
Operation::set_split_kv(arguments);
}
Operation op;
ExampleResult example_result;
example_result.smem_size = Operation::Kernel::SharedStorageSize;
size_t workspace_size = 0;
workspace_size = Operation::get_workspace_size(arguments);
DeviceAllocation<uint8_t> workspace(workspace_size);
cutlass::Status status = cutlass::Status::kSuccess;
status = op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
status = op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
// Run
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return example_result;
}
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result = cudaEventCreate(&event);
if (result != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
}
// Record an event at the start of a series of GEMMs
result = cudaEventRecord(events[0]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
for (int i = 0; i < options.iterations; i++) {
status = op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result = cudaEventRecord(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
// Wait for work on the device to complete.
result = cudaEventSynchronize(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
return example_result;
}
runtime_ms /= static_cast<float>(options.iterations);
double flops = 1.0;
flops *= B;
flops *= K;
flops *= H;
flops *= 2.0;
flops *= (2.0 * D_latent + D_rope);
double bytes_q = sizeof(Element);
bytes_q *= B;
bytes_q *= H;
bytes_q *= (D_latent + D_rope);
double bytes_c = sizeof(Element);
bytes_c *= B;
bytes_c *= options.k; // K may be max_K here
bytes_c *= (D_latent + D_rope);
double bytes_o = sizeof(ElementOut);
bytes_o *= B;
bytes_o *= H;
bytes_o *= D_latent;
double bytes = bytes_q + bytes_c + bytes_o;
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
example_result.tbytes_s = tbytes_s;
example_result.runtime_ms = runtime_ms;
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return example_result;
}
// Verify that the result is correct
bool passed = true;
if (options.verify) {
passed = verify(problem_shape);
if (passed) example_result.verified = true;
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
return example_result;
}
example_result.passed = true;
return example_result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
if (! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
if (verbose) {
std::cout << " t=" << result.runtime_ms * 1e3 << " us, "
"smem=" << result.smem_size << "b" << std::endl;
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, const char* name, auto... kernel_options) {
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
return;
}
Runner<decltype(shape), decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
};
using NumHeads = _128;
using HeadDimLatent = _512;
using HeadDim = Shape<HeadDimLatent, _64>;
std::cout << "###### B " << options.b << " MLA H " << 0 + NumHeads{} << " ";
std::cout << "D_rope " << 0 + get<1>(HeadDim{}) << " D_latent " << 0 + get<0>(HeadDim{}) << " ";
std::cout << "Q 1 K " << options.k << " Gen None ";
std::cout << "Split " << options.split_kv << " Gen None ";
std::cout << "#SM " << hw_info.sm_count << std::endl;
using Blocking = _128;
std::string name = std::to_string((int) NumHeads{}) + "x" + std::to_string((int) Blocking{});
std::string individual = " individual";
std::string persistent = " persistent";
#if FP8
name += " fp8";
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
#elif FP16
name += " fp16";
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
#endif
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_single(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
std::cout
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
<< "(compute capability major 10) and CUDA 12.8 or greater.\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
if (options.sm_count == 0) {
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
else {
hw_info.sm_count = options.sm_count;
}
run_mla(options, hw_info);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
auto arg = full_arguments[i];
size_t eq_pos = arg.find('=');
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
for (;;) {
size_t comma_pos = rest.find(',');
std::string current = rest.substr(0, comma_pos);
full_arguments[i] = prefix + current;
std::vector<const char*> next_args;
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
main(argc, next_args.data());
if (comma_pos == std::string::npos) break;
rest = rest.substr(comma_pos+1);
}
recursed = true;
break;
}
}
if (! recursed) {
main_single(argc, args);
}
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -28,12 +28,14 @@
set_property(
SOURCE 77_blackwell_fmha.cu
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
set_property(
SOURCE 77_blackwell_fmha_gen.cu
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0")
SOURCE
77_blackwell_fmha.cu
77_blackwell_fmha_gen.cu
77_blackwell_mla.cu
77_blackwell_fmha_bwd.cu
PROPERTY
COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
)
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
@ -41,65 +43,137 @@ set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --v
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
set(TEST_VARLEN_00 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_01 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_02 --verify --varlen --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_03 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_04 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_05 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_06 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
set(TEST_VARLEN_07 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
set(TEST_VARLEN_08 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
set(TEST_VARLEN_09 --verify --varlen --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
set(TEST_VARLEN_10 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify)
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")))
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
foreach(PREC fp8 fp16)
string(TOUPPER "${PREC}" PREC_MACRO)
cutlass_example_add_executable(
77_blackwell_fmha_${PREC}
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_VARLEN
TEST_HDIM64
TEST_GQA
TEST_VARLEN_00
TEST_VARLEN_01
TEST_VARLEN_02
TEST_VARLEN_03
TEST_VARLEN_04
TEST_VARLEN_05
TEST_VARLEN_06
TEST_VARLEN_07
TEST_VARLEN_08
TEST_VARLEN_09
TEST_VARLEN_10
TEST_VARLEN_11
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
)
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
cutlass_example_add_executable(
77_blackwell_fmha_gen_${PREC}
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
TEST_GEN_VARLEN
# TEST_GEN_HDIM64
TEST_GEN_GQA
TEST_GEN_REMAP
TEST_GEN_CACHEONLY
)
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
cutlass_example_add_executable(
77_blackwell_mla_2sm_${PREC}
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
)
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_mla_2sm_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_mla_2sm_cpasync_${PREC}
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
)
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_${PREC}
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_VARLEN
)
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_sat_${PREC}
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_GEN_VARLEN
TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
endforeach()
# Add a target that builds all examples
add_custom_target(77_blackwell_fmha_all
DEPENDS
77_blackwell_fmha_fp8
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8)
cutlass_example_add_executable(
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8)
cutlass_example_add_executable(
77_blackwell_fmha_fp16
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_CAUSAL
# TEST_VARLEN
# TEST_HDIM64
# TEST_GQA)
)
target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
cutlass_example_add_executable(
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen_fp16
77_blackwell_fmha_gen.cu
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
# TEST_GEN_VARLEN
# TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
)
target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
endif()
77_blackwell_mla_2sm_fp8
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
)
endif()


@ -21,3 +21,68 @@ To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easies
The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements.
It is well-suited for applying masks or activations.
More complex fusions that need additional memory loads would require modifying the mainloop collective to orchestrate those loads via TMA.
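For illustration, a custom fusion can be written as a mask functor in the style of `NoMask`/`CausalMask`. The sketch below assumes the `apply_mask` interface described above (accumulator fragment of the first GEMM, a coordinate tensor of logical (Q, K) positions, and the problem size) and the `using namespace` declarations at the top of the sample; the `LocalWindowMask` name and its `window` member are hypothetical and only show where the masking logic goes — check `collective/fmha_fusion.hpp` for the exact signature.
```
// Hedged sketch of a custom fusion; LocalWindowMask and window are illustrative
// names, not part of this sample's code.
struct LocalWindowMask : NoMask {
  int window = 128;  // hypothetical sliding-window size

  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE void apply_mask(
      AccQK& acc_qk,               // accumulator of the first GEMM (S = Q K^T)
      IndexQK const& index_qk,     // logical (q, k) position of each accumulator element
      ProblemSize const& problem_size) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      int q = get<0>(pos);
      int k = get<1>(pos);
      // Keep keys in the window [q - window, q]; everything else becomes -inf,
      // so softmax assigns it zero weight.
      if (k > q || k < q - window) {
        acc_qk(i) = -INFINITY;
      }
    }
  }
};
```
Such a functor would then be passed where the samples currently pass `NoMask{}` or `CausalMask{}`.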
# FMHA for Blackwell: Backward
This sample provides code for the fused multi-head attention backward pass.
It supports head dimensions of 64 and 128, and fp8, fp16, and bf16 input data types.
The blocking in the Q and K sequence lengths is 128, and loads are done via TMA.
Causal masking is supported.
The structure of this code is very similar to the forward pass, and the techniques are analogous.
The backward pass is computed by three kernels:
1. `FmhaKernelBwdSumOdO` to compute the sum of the outer product of O and dO.
2. `Sm100FmhaBwdKernelTmaWarpSpecialized` to compute the backward pass proper.
3. `FmhaKernelBwdConvert` to convert dQ from fp32 to the final output precision.
`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main focus of this sample, as it demonstrates how to use the tensor cores to achieve a high-performance fused kernel.
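As a quick orientation, the host-side flow boils down to the sketch below, condensed from `BwdRunner::run()` in `77_blackwell_fmha_bwd.cu`; the pointer and stride variables (`ptr_Q`, `stride_Q`, ..., `softmax_scale`, `hw_info`) are placeholders for buffers and metadata prepared as in `BwdRunner::initialize()`, and the `using namespace` declarations from the sample are assumed.
```
// Condensed from BwdRunner::run(); pointers/strides stand in for data set up
// as in BwdRunner::initialize().
using Element    = cutlass::half_t;   // cutlass::float_e4m3_t when FP8 is defined
using ElementAcc = float;
using TileShape  = cute::Shape<cute::_128, cute::_128, cute::_128>;  // Q-block, K-block, head dim
using Operation  = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAcc, TileShape, NoMask>;

typename Operation::Arguments arguments{
  problem_shape,                                  // (Q, K, D, (H, B))
  ptr_Q, stride_Q, ptr_K, stride_K, ptr_V, stride_V,
  ptr_O, stride_O, ptr_LSE, stride_LSE,           // forward outputs consumed by the backward pass
  ptr_dO, stride_dO,                              // incoming gradient
  ptr_dQ, stride_dQ, ptr_dK, stride_dK, ptr_dV, stride_dV,
  softmax_scale, hw_info
};

Operation op;
if (op.can_implement(arguments) == cutlass::Status::kSuccess) {
  DeviceAllocation<uint8_t> workspace(Operation::get_workspace_size(arguments));
  op.initialize(arguments, workspace.get());
  op.run();                                       // launches the backward computation
}
```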
# MLA Inference for Blackwell
This sample provides code for fused multi-head latent attention inference in
the weight-absorbed regime, i.e. with a latent head dim of 512 and a rope head dim of 64.
It supports fp16, bf16, and fp8 input and output types.
To accommodate the large output accumulator that results from the large latent head dimension,
the sample demonstrates how to leverage 2SM Blackwell tensor cores.
Loading can be done via TMA (either without paging or with a page size of 128), or via `cp.async`
to support any power-of-two page size less than or equal to 128.
With paging, the code also supports variable sequence lengths.
The approach of this implementation is to reuse the selection logic of the collective GEMM builder and recombine the result into an MLA kernel.
The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back GEMM variant (essentially turning the softmax into a no-op), for both fp8 and fp16.
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or pass `--help` to each binary.
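For orientation, the host-side setup reduces to the sketch below, condensed from `Runner::run()` in `77_blackwell_mla.cu`; pointer, stride, and page-table variables are placeholders for data prepared as in `Runner::initialize()` (in the non-paged case the sequence-length and page-table pointers are simply null), and the `using namespace` declarations from the sample are assumed.
```
// Condensed from Runner::run(); pointers/strides/page tables stand in for data
// set up as in Runner::initialize().
using Element    = cutlass::half_t;   // cutlass::float_e4m3_t when FP8 is defined
using ElementAcc = float;
using ElementOut = cutlass::half_t;
using TileShape  = cute::Shape<cute::_128, cute::_128, cute::Shape<cute::_512, cute::_64>>; // H, K-block, (D_latent, D_rope)
using Kernel     = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
    TileShape, Element, ElementAcc, ElementOut, ElementAcc, Sm100MlaPersistentTileScheduler>;
using Operation  = cutlass::fmha::device::MLA<Kernel>;

typename Operation::Arguments arguments{
  problem_shape,                                       // (H, K, (D_latent, D_rope), B)
  { scale,
    ptr_Q_latent, stride_Q_latent, ptr_Q_rope, stride_Q_rope,
    ptr_C_latent, stride_C_latent, ptr_K_rope, stride_K_rope,
    ptr_seq,                                           // nullptr when not paged
    ptr_page_table, stride_PT, page_count, page_size },
  { ptr_O, stride_O, ptr_LSE, stride_LSE },
  hw_info,
  split_kv,                                            // -1 lets set_split_kv pick a value
  nullptr                                              // per-batch split-KV table (variable split only)
};
if (split_kv < 0) {
  Operation::set_split_kv(arguments);                  // heuristic split along the KV dimension
}

Operation op;
if (op.can_implement(arguments) == cutlass::Status::kSuccess) {
  DeviceAllocation<uint8_t> workspace(Operation::get_workspace_size(arguments));
  op.initialize(arguments, workspace.get());
  op.run();
}
```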
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -157,7 +157,8 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
}
template<class BlkCoord, class TileShape, class ProblemSize>

View File

@ -42,7 +42,7 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE // Q, B
class StrideLSE_ // Q, B
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
@ -54,6 +54,8 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
using ElementOut = Element;
struct TensorStorage {
@ -79,6 +81,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
struct Params {
TMA_O tma_store_o;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
};
template<class ProblemShape>
@ -110,7 +115,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
);
return {
tma_store_o
tma_store_o,
args.ptr_LSE,
args.dLSE
};
}
@ -119,6 +126,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
}
const Params& params;
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(

View File

@ -505,12 +505,12 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ...
}
template<bool need_apply_mask, class Stage, class BlkCoord, class CountingTensor, class ProblemShape>
template<bool need_apply_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
CUTLASS_DEVICE auto
softmax_step(
float& row_max, float& row_sum,
Stage stage, bool final_call,
BlkCoord const& blk_coord, CountingTensor const& cS,
BlkCoord const& blk_coord, CoordTensor const& cS,
Params const& params, ProblemShape const& problem_shape,
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
@ -531,7 +531,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
@ -613,7 +613,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
const int kReleasePipeCount = 10; // must be multiple of 2
order_s.wait();
CUTLASS_PRAGMA_UNROLL
@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
#ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {
@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 16;
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
@ -942,16 +942,21 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}
template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
@ -961,7 +966,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
@ -1060,13 +1065,30 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
@ -1083,6 +1105,21 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
@ -1092,6 +1129,85 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
++pipeline_epi_producer_state;
}
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction_empty(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
float lse = -INFINITY;
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if (threadIdx.x % 128 == 0 && block0()) {
DSHOW(sO);
}
#if 1
using ElementOut = typename CollectiveEpilogue::ElementOut;
auto tiled_copy = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
sO.layout());
auto thr_copy = tiled_copy.get_slice(thread_idx);
auto tOgO = thr_copy.partition_D(sO);
auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
clear(tOrO);
copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
#endif
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
};
} // namespace cutlass::fmha::collective

View File

@ -514,12 +514,12 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ...
}
template<bool need_apply_mask, class Stage, class BlkCoord, class CountingTensor, class ProblemShape>
template<bool need_apply_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
CUTLASS_DEVICE auto
softmax_step(
float& row_max, float& row_sum,
Stage stage, bool final_call,
BlkCoord const& blk_coord, CountingTensor const& cS,
BlkCoord const& blk_coord, CoordTensor const& cS,
Params const& params, ProblemShape const& problem_shape,
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
@ -831,7 +831,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// loop:
// TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 128 / kCorrectionTileSize; i++) {
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0;
tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1;
@ -917,7 +917,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<get<2>(TileShape{}) / kCorrectionTileSize>{}));
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;

View File

@ -170,8 +170,8 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
auto tSgQ = thr_mma_qk.partition_A(gQ);
auto tScQ = thr_mma_qk.partition_A(cQ);
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
auto tiled_copy_q = make_cotiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},

View File

@ -0,0 +1,92 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cuda_runtime.h>
namespace cutlass::fmha {
struct Pow2 {
int n;
int log2_n;
explicit CUTE_DEVICE Pow2(int n) : n(n) {
#ifdef __CUDA_ARCH__
log2_n = __ffs(n) - 1;
#endif
}
template<class T>
CUTE_HOST_DEVICE T operator *(T const& b) const {
return n * b;
}
template<int N>
CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {
if constexpr ((N & (N - 1)) == 0) {
return Pow2{n * N};
}
else {
return n * N;
}
}
};
template<class T>
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
return a >> b.log2_n;
}
template<class T>
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
return a & (b.n - 1);
}
template<class T>
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
return a < b.n;
}
CUTE_HOST_DEVICE void print(Pow2 const& a) {
printf("2^%d", a.log2_n);
}
} // end namespace cutlass::fmha
namespace cute {
template <>
struct is_integral<cutlass::fmha::Pow2> : true_type {};
} // end namespace cute

View File

@ -0,0 +1,320 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/tensor.hpp"
#include "../device/fmha.hpp"
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<class Element, class ElementAccumulator, class TileShape, class Mask>
class Sm100FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
// Q K D HB
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
const Element* ptr_K;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
const Element* ptr_V;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
Element* ptr_dQ;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
Element* ptr_dK;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
Element* ptr_dV;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
ElementAccumulator softmax_scale;
cutlass::KernelHardwareInfo hw_info;
};
using OperationSumOdO = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
>;
using OperationConvert = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
>;
using Operation = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
>;
using Kernel = typename Operation::Kernel;
struct Params {
OperationSumOdO op_sum_OdO;
Operation op;
OperationConvert op_convert;
ElementAccumulator* dQ_acc;
size_t dQ_acc_size;
};
private:
Params params_;
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(
Arguments const& args,
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_size,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
sum_odo, stride_sum_OdO,
args.ptr_LSE, args.stride_LSE,
scaled_lse, stride_scaled_lse,
-1.0f, -log2_e
};
}
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
return typename OperationConvert::Arguments {
args.problem_size,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
args.softmax_scale
};
}
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_size,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
scaled_lse, stride_scaled_lse,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ,
args.softmax_scale },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
args.hw_info
};
}
public:
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
Status status = Status::kSuccess;
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = OperationConvert::can_implement(to_convert_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = Operation::can_implement(to_bwd_arguments(args));
if (status != Status::kSuccess) {
return status;
}
return status;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
// scaled LSE vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
// FP32 versions of outputs that are churned (start off with Q only)
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
return workspace_bytes;
}
/// Initializes state from arguments.
Status
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << ", stream: " << (stream ? "non-null" : "null"));
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
params_.dQ_acc = dQ_acc;
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
auto args_convert = to_convert_arguments(args, dQ_acc);
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
params_.op_convert.initialize(args_convert, nullptr, stream);
auto args_bwd = to_bwd_arguments(
args, sum_OdO, args_sum_OdO.stride_sum_OdO,
scaled_lse, args_sum_OdO.stride_scaled_lse,
dQ_acc, args_convert.stride_src_dQ
);
params_.op.initialize(args_bwd, nullptr, stream);
return Status::kSuccess;
}
/// Initializes state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
Status result = Status::kSuccess;
result = params.op_sum_OdO.run(stream);
if (result != Status::kSuccess) {
return result;
}
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
if (cuda_result != cudaSuccess) {
return Status::kErrorInternal;
}
result = params.op.run(stream);
if (result != Status::kSuccess) {
return result;
}
result = params.op_convert.run(stream);
if (result != Status::kSuccess) {
return result;
}
return Status::kSuccess;
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,357 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief A universal device layer for cutlass 3.x-style kernels.
*/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
using namespace cute;
using namespace cutlass::fmha::kernel;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class Kernel_
>
class MLA {
public:
using Kernel = Kernel_;
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
typename Kernel::ElementOut,
typename Kernel::ElementAcc,
typename Kernel::ElementAcc,
Kernel::TileShapeH::value,
Kernel::TileShapeL::value,
256 /*Max split*/
>;
/// Argument structure: User API
using KernelArguments = typename Kernel::Arguments;
using ReductionArguments = typename ReductionKernel::Arguments;
using Arguments = KernelArguments;
/// Argument structure: Kernel API
using KernelParams = typename Kernel::Params;
using ReductionParams = typename ReductionKernel::Params;
struct Params {
KernelParams fmha_params;
ReductionParams reduction_params;
};
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
static ReductionArguments to_reduction_args(Arguments const& args) {
auto [H, K, D, B] = args.problem_shape;
return ReductionArguments{
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
args.ptr_split_kv, Kernel::TileShapeS::value
};
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
static void set_split_kv (KernelArguments& args) {
if (args.split_kv >= 1) return;
auto [H, K, D, B] = args.problem_shape;
int sm_count = args.hw_info.sm_count;
int max_splits = ceil_div(K, 128);
int sms_per_batch = max(1, sm_count / B);
int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (! Kernel::can_implement(args)) {
return Status::kInvalid;
}
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
return Status::kInvalid;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
return workspace_bytes;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
if (status != Status::kSuccess) {
return status;
}
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {kernel_params, reduction_params};
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
// no dynamic smem is needed for reduction kernel
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {fmha_params, reduction_params};
return Status::kSuccess;
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params.fmha_params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess != result or Status::kSuccess != launch_result) {
//return Status::kSuccess;
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
else {
return Status::kSuccess;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,146 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<class Element, class ElementAcc>
struct FmhaKernelBwdConvert {
struct Arguments {
tuple<int, int, int, tuple<int, int>> problem_size;
const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
const ElementAcc* ptr_src_dK;
tuple<int, _1, tuple<int, int>> stride_src_dK;
const ElementAcc* ptr_src_dV;
tuple<int, _1, tuple<int, int>> stride_src_dV;
Element* ptr_dest_dQ;
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, _1, tuple<int, int>> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, _1, tuple<int, int>> stride_dest_dV;
ElementAcc scale = 1.0;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm90;
static const int kBlockSeq = 8;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kNumThreadsD = 16;
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 4;
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
template<class StrideSrc, class StrideDest>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
if (idx_s >= count) continue;
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
ElementAcc value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAcc> * kElementsPerLoad>;
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
value_dest[v] = static_cast<Element>(params.scale * value_src[v]);
}
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
}
}
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size));
}
}
};
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,151 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<class Element, class ElementAcc>
struct FmhaKernelBwdSumOdO {
struct Arguments {
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
ElementAcc* ptr_sum_OdO;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
const ElementAcc* ptr_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
ElementAcc* ptr_scaled_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
ElementAcc sum_odo_scale = 1.0;
ElementAcc lse_scale = 1.0;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm100;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kBlockQ = 16;
static const int kNumThreadsD = 8;
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 2;
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<2>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= get<0>(params.problem_size)) continue;
ElementAcc acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO);
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
acc += value_O[v] * value_dO[v];
}
}
for (int i = 1; i < kNumThreadsD; i *= 2) {
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
}
if (threadIdx.x == 0) {
*ptr_sum_OdO_bhq = params.sum_odo_scale * acc;
if (params.ptr_scaled_lse) {
*ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq;
}
}
}
}
};
} // namespace cutlass::fmha::kernel

View File

@ -90,8 +90,8 @@ struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_h;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
@ -146,7 +146,7 @@ struct PersistentTileScheduler {
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE

File diff suppressed because it is too large.

View File

@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
CollectiveMainloop mainloop;
CollectiveEpilogue epilogue;
CollectiveEpilogue epilogue{params.epilogue};
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
@ -372,6 +372,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
bool is_softmax_0 = role == WarpRole::Softmax0;
mainloop.softmax(
@ -400,17 +404,30 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
mainloop.correction_empty(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
shared_storage.epilogue,
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
continue;
}
mainloop.correction(
blk_coord,
params.mainloop, logical_problem_shape,
params.problem_shape,
shared_storage.epilogue,
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);
}
if constexpr (NumWarpsEpilogue == 0) {
@ -438,6 +455,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.mma(
blk_coord,
@ -450,7 +470,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_mma_corr, pipeline_mma_corr_producer_state
);
}
}
else if (role == WarpRole::Load) {
@ -467,6 +486,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
mainloop.load(
blk_coord, logical_problem_shape,
params.mainloop, params.problem_shape,

View File

@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class ElementOut,
class ElementAcc,
class ElementScale,
size_t kNumHeads,
size_t kHeadDimLatent,
int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
static const int SharedStorageSize = 0;
static const int MaxThreadsPerBlock = 128;
static const int MinBlocksPerMultiprocessor = 1;
using ArchTag = cutlass::arch::Sm100;
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
struct Arguments {
ElementAcc* ptr_oaccum = nullptr;
ElementOut* ptr_o = nullptr;
ElementAcc* ptr_lseaccum = nullptr;
ElementAcc* ptr_lse = nullptr;
ElementScale scale = 1.f;
int num_batches = 0;
int split_kv = -1;
int dim_k = -1;
int* ptr_seq = nullptr;
int* ptr_split_kv = nullptr;
int tile_shape_s = 128;
};
using Params = Arguments;
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
args.ptr_split_kv, args.tile_shape_s};
}
static size_t get_workspace_size(Arguments const& /*args*/) {
return 0;
}
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
return Status::kSuccess;
}
static dim3 get_grid_shape(Params const& params) {
return dim3(kNumHeads, 1, params.num_batches);
}
static dim3 get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
static bool can_implement(Arguments const& args) {
if (args.num_batches <= 0) return false;
if (args.split_kv <= 0) return false;
return true;
}
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
ElementAcc local_lse[kNLsePerThread];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = max(lse_max, local_lse[i]);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
}
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
if (split < local_split_kv) {
sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
}
__syncthreads();
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
ElementAcc local_val[Elements] = {0};
for (int split = 0; split < local_split_kv; ++split) {
ElementAcc lse_scale = sLseScale[split];
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
}
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
}
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
}
}
};
} // namespace cutlass::fmha::kernel
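
The combine that this reduction kernel performs can be summarized on the host in a few lines. The sketch below is an illustrative, hypothetical restatement for a single (head, batch) slice — `lse_acc`, `o_acc`, and `head_dim` are names invented for the example — showing the same numerically stable log-sum-exp merge and the per-split rescaling that the kernel stages through sLseScale.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Host-side sketch (not the kernel) of the split-KV combine above, for one
// (head, batch) slice: lse_acc[s] is the per-split LSE, o_acc[s*head_dim + d]
// is the per-split partial output row.
inline void combine_splits(std::vector<float> const& lse_acc,
                           std::vector<float> const& o_acc,
                           int head_dim,
                           std::vector<float>& o_out, float& lse_out) {
  int splits = static_cast<int>(lse_acc.size());
  float lse_max = -std::numeric_limits<float>::infinity();
  for (int s = 0; s < splits; ++s) lse_max = std::max(lse_max, lse_acc[s]);
  if (lse_max == -std::numeric_limits<float>::infinity()) lse_max = 0.f;  // all splits empty
  float sum = 0.f;
  for (int s = 0; s < splits; ++s) sum += std::exp(lse_acc[s] - lse_max);
  lse_out = (sum == 0.f) ? std::numeric_limits<float>::infinity()
                         : std::log(sum) + lse_max;
  o_out.assign(head_dim, 0.f);
  for (int s = 0; s < splits; ++s) {
    float scale = std::exp(lse_acc[s] - lse_out);  // the factor the kernel stores in sLseScale
    for (int d = 0; d < head_dim; ++d) {
      o_out[d] += scale * o_acc[s * head_dim + d];
    }
  }
}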

File diff suppressed because it is too large.


@@ -0,0 +1,160 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaIndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler(Params const&) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
}
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_split_kv;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = size<0>(cluster_shape);
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
num_blocks *= split_kv; /* Maximum Split KV*/
return Params {
num_blocks,
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
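
For readers unfamiliar with FastDivmod, the persistent scheduler's coordinate decode is just three successive divide-and-remainder steps over a flat block index. The plain-C++ sketch below (illustrative only; `num_m_blocks`, `num_batches`, and `split_kv` stand for the divisors captured in Params) reproduces the same arithmetic as get_block_coord(); operator++ then simply strides the index by gridDim.x.

// Plain-arithmetic sketch of Sm100MlaPersistentTileScheduler::get_block_coord().
struct DecodedCoord { int m_block; int batch; int split; };

inline DecodedCoord decode_block(int block_idx, int num_m_blocks,
                                 int num_batches, int split_kv) {
  int rest    = block_idx;
  int m_block = rest % num_m_blocks; rest /= num_m_blocks;  // divmod_m_block
  int batch   = rest % num_batches;  rest /= num_batches;   // divmod_b
  int split   = rest % split_kv;                            // divmod_split_kv
  return {m_block, batch, split};
}
// A persistent CTA starts at blockIdx.x and repeats decode / work /
// block_idx += gridDim.x until block_idx >= num_m_blocks * num_batches * split_kv.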


@@ -0,0 +1,311 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
class TensorDQ, /* class TensorDK, class TensorDV, */
class Fusion
>
void __global__ fmha_bwd_reference_dQ_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
Fusion fusion) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
} // for idx_D0
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_K] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_K
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
} // for idx_D
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/* class TensorDQ, */ class TensorDK, /* class TensorDV, */
class Fusion
>
void __global__ fmha_bwd_reference_dK_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
Fusion fusion) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
} // for idx_D0
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_Q
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
}
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
} // for idx_D
} // for idx_K
} // for idx_L
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/* class TensorDQ, class TensorDK, */ class TensorDV,
class Fusion
>
void __global__ fmha_bwd_reference_dV_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
Fusion fusion) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAcc softmax_scale = static_cast<ElementAcc>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
acc_qk += rQ * rK;
} // for idx_D0
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)));
} // for idx_Q
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
ElementAcc rS = mS[idx_Q];
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
acc += rS * rDO;
}
mDV(idx_K, idx_D, idx_L) = static_cast<typename TensorDV::value_type>(acc);
} // for idx_D
} // for idx_K
} // for idx_L
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/
class Fusion
>
void fmha_bwd_reference_dQ(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/
Fusion fusion) {
using namespace cute;
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/
class Fusion
>
void fmha_bwd_reference_dK(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/
Fusion fusion) {
using namespace cute;
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/
class Fusion
>
void fmha_bwd_reference_dV(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/
Fusion fusion) {
using namespace cute;
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
class TensorDQ, class TensorDK, class TensorDV,
class Fusion
>
void fmha_bwd_reference(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, TensorDK mDK, TensorDV mDV,
Fusion fusion) {
fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
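
Putting the three kernels above side by side, the value each one stages in shared memory and the final reductions match the standard attention backward pass. With s the softmax scale (1/sqrt(D) here), LSE_q the stored per-row log-sum-exp, and the fusion's mask already applied to the raw Q·K score, the reference computes, per batch/head slice:

  P_{qk} = \exp\big(s\,(Q_q \cdot K_k) - \mathrm{LSE}_q\big), \qquad
  dS_{qk} = s\, P_{qk}\,\big(dO_q \cdot V_k - dO_q \cdot O_q\big),

  dQ_q = \sum_k dS_{qk}\, K_k, \qquad
  dK_k = \sum_q dS_{qk}\, Q_q, \qquad
  dV_k = \sum_q P_{qk}\, dO_q.

The dQ and dK kernels share the dS term (computed per key column in the first and per query row in the second), while the dV kernel only needs the probabilities P.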


@@ -80,6 +80,17 @@ void __global__ fmha_reference_kernel(
if constexpr (rank<1>(decltype(coord){}) == 2) {
offset_K = get<1,1>(coord);
}
if (get<1>(problem_shape) == 0) {
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0);
}
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = -INFINITY;
}
continue;
}
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc = 0;
@@ -127,8 +138,8 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}
if (threadIdx.x == 0) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + maxS;
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}
}
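
The two hunks above adjust the forward reference in three ways: a zero-length key sequence now writes O = 0 and LSE = -INFINITY (an empty sum), the LSE store becomes optional, and the stored LSE picks up the softmax scale on the subtracted maximum. The scale factor follows from the usual max-subtraction identity (assuming maxS is the maximum of the unscaled scores a_qk, as in the MLA reference further down):

  \mathrm{LSE}_q = \log \sum_k e^{\,s\,a_{qk}}
                 = \log\Big(\sum_k e^{\,s\,(a_{qk} - a_{q,\max})}\Big) + s\, a_{q,\max},

so the term added back after the stable summation is softmax_scale * maxS rather than maxS alone.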


@@ -0,0 +1,201 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorSeq,
class TensorPageTable,
class TensorQL,
class TensorQR,
class TensorCL,
class TensorKR,
class TensorO,
class TensorLSE,
class Scale
>
void __global__ fmha_mla_reference_kernel(
ProblemShape problem_shape,
TensorSeq mSeq, TensorPageTable mPT,
TensorQL mQL, TensorQR mQR,
TensorCL mCL, TensorKR mKR,
TensorO mO, TensorLSE mLSE,
Scale softmax_scale) {
using namespace cute;
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
using Element = typename TensorO::value_type;
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ ElementAcc mS[];
// ElementAcc* mS = reinterpret_cast<ElementAcc*>(mS_mem);
for (int idx_B = blockIdx.y; idx_B < B; idx_B += gridDim.y) {
if (mSeq.data() != nullptr) {
K = mSeq(idx_B);
}
for (int idx_H = blockIdx.x; idx_H < H; idx_H += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
ElementAcc acc = 0;
for (int idx_D = 0; idx_D < D_latent; idx_D++) {
int page_idx_K = idx_K;
int page_idx_B = idx_B;
if (mPT.data() != nullptr) {
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
page_idx_K = idx_K % size<0>(mCL);
}
ElementAcc eQ = mQL(idx_H, idx_D, idx_B);
ElementAcc eK = mCL(page_idx_K, idx_D, page_idx_B);
acc += eQ * eK;
}
for (int idx_D = 0; idx_D < D_rope; idx_D++) {
int page_idx_K = idx_K;
int page_idx_B = idx_B;
if (mPT.data() != nullptr) {
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
page_idx_K = idx_K % size<0>(mCL);
}
ElementAcc eQ = mQR(idx_H, idx_D, idx_B);
ElementAcc eK = mKR(page_idx_K, idx_D, page_idx_B);
acc += eQ * eK;
}
mS[idx_K] = acc;
}
__syncthreads();
ElementAcc maxS = -std::numeric_limits<ElementAcc>::infinity();
for (int idx_K = 0; idx_K < K; idx_K++) {
maxS = std::max<ElementAcc>(maxS, mS[idx_K]);
}
if (maxS == -std::numeric_limits<ElementAcc>::infinity()) maxS = 0;
__syncthreads();
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS));
}
__syncthreads();
ElementAcc sum = 0;
for (int idx_K = 0; idx_K < K; idx_K++) {
sum += mS[idx_K];
}
ElementAcc o_scale = 1.0f / sum;
for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_K = 0; idx_K < K; idx_K++) {
int page_idx_K = idx_K;
int page_idx_B = idx_B;
if (mPT.data() != nullptr) {
page_idx_B = mPT(idx_K / size<0>(mCL), idx_B);
page_idx_K = idx_K % size<0>(mCL);
}
ElementAcc eV = mCL(page_idx_K, idx_D, page_idx_B);
ElementAcc eK = static_cast<Element>(mS[idx_K]);
acc += eK * eV;
}
mO(idx_H, idx_D, idx_B) = static_cast<typename TensorO::value_type>(acc * o_scale);
}
if (threadIdx.x == 0) {
mLSE(idx_H, idx_B) = log(sum) + softmax_scale * maxS;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorSeq,
class TensorPageTable,
class TensorQL,
class TensorQR,
class TensorCL,
class TensorKR,
class TensorO,
class TensorLSE,
class Scale
>
void fmha_mla_reference(
ProblemShape problem_shape,
TensorSeq mSeq, TensorPageTable mPT,
TensorQL mQL, TensorQR mQR,
TensorCL mCL, TensorKR mKR,
TensorO mO, TensorLSE mLSE,
Scale scale) {
using namespace cute;
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
dim3 grid(H, B, 1);
dim3 block(256);
int shared_mem = K * int(sizeof(typename TensorLSE::value_type)) + 16;
cudaError_t result;
if (shared_mem >= (48 << 10)) {
result = cudaFuncSetAttribute(
&fmha_mla_reference_kernel<ProblemShape, TensorSeq, TensorPageTable, TensorQL, TensorQR, TensorCL, TensorKR, TensorO, TensorLSE, Scale>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
throw std::runtime_error("couldn't perform smem optin");
}
}
fmha_mla_reference_kernel<<<grid, block, shared_mem>>>(
problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale);
cudaDeviceSynchronize();
result = cudaGetLastError();
if (cudaSuccess != result) {
throw std::runtime_error("couldn't execute reference");
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
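
The page-table translation repeated inside the loops above is the only indirection the MLA reference needs for paged KV. A minimal standalone sketch of that arithmetic (hypothetical names; `page_size` plays the role of size<0>(mCL), and the table is assumed to hold pages_per_batch entries per batch, matching mPT(page, batch)) is:

// Sketch of the paged-KV lookup in fmha_mla_reference_kernel: a logical key
// index idx_K in batch idx_B becomes (physical page, in-page offset).
struct PagedIndex { int physical_page; int offset; };

inline PagedIndex translate(int idx_K, int idx_B, int page_size,
                            int const* page_table, int pages_per_batch) {
  int logical_page = idx_K / page_size;   // which page of this sequence
  int offset       = idx_K % page_size;   // row within the page
  // Assumes pages are stored contiguously per batch, matching mPT(page, batch).
  int physical = page_table[logical_page + pages_per_batch * idx_B];
  return {physical, offset};
}
// When no page table is provided (mPT.data() == nullptr), the kernel falls back
// to physical_page = idx_B and offset = idx_K, i.e. a dense KV cache.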


@@ -75,6 +75,8 @@ struct DeviceAllocation {
size_t size() const { return size_; }
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess);
@@ -99,8 +101,12 @@ __global__ void reference_abs_diff_kernel(
__shared__ double block_sum_diff;
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
if (data[i] == data_ref[i]) {
continue;
}
double diff = fabs(data[i] - data_ref[i]);
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
thread_max_diff = fmax(diff, thread_max_diff);
thread_sum_diff += diff;
}
@@ -178,3 +184,99 @@ void reference_abs_diff(
max_diff = result_host[0];
mean_diff = result_host[1] / static_cast<double>(data.size());
}
template<typename Element>
__global__ void reference_rel_diff_kernel(
Element* data, Element* data_ref, size_t count,
double* max_diff, double* sum_diff,
bool print_diff ) {
double thread_max_diff = 0;
double thread_sum_diff = 0;
__shared__ double block_max_diff;
__shared__ double block_sum_diff;
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
if (data[i] == data_ref[i]) {
continue;
}
double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]);
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
thread_max_diff = fmax(diff, thread_max_diff);
thread_sum_diff += diff;
}
for (int i = 0; i < blockDim.x; i++) {
if (i == threadIdx.x) {
if (i == 0) {
block_max_diff = thread_max_diff;
block_sum_diff = thread_sum_diff;
}
else {
block_max_diff = fmax(block_max_diff, thread_max_diff);
block_sum_diff += thread_sum_diff;
}
}
__syncthreads();
}
if (threadIdx.x == 0) {
atomicAdd(sum_diff, block_sum_diff);
for (;;) {
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
double prev_diff = reinterpret_cast<double const&>(prev);
double new_max_diff = fmax(block_max_diff, prev_diff);
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
if (found == prev) break;
}
}
}
template<typename Element>
void reference_rel_diff(
DeviceAllocation<Element> const& data,
DeviceAllocation<Element> const& data_ref,
double& max_diff, double& mean_diff) {
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
DeviceAllocation<double> result;
result.reset(2);
assert(data.size() == data_ref.size());
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
if (err != cudaSuccess) {
std::cerr << "Memset failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
dim3 block(256, 1, 1);
dim3 grid(1024, 1, 1);
reference_rel_diff_kernel<<<block, grid>>>(
data.get(), data_ref.get(), data.size(),
result.get(), result.get() + 1, kPrintDiff);
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
std::cerr << "Difference kernel failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
double result_host[2];
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
if (err != cudaSuccess) {
std::cerr << "Copy failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
max_diff = result_host[0];
mean_diff = result_host[1] / static_cast<double>(data.size());
}
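
Both error metrics are meant to be used the same way. A hedged usage sketch (assuming reference_abs_diff shares the signature of reference_rel_diff above, and with an illustrative tolerance that is not taken from the source) could look like:

#include <iostream>
// Illustrative only: DeviceAllocation and the reference_*_diff helpers come
// from this header; the tolerance below is an example, not a project default.
template <typename Element>
bool check_against_reference(DeviceAllocation<Element> const& out,
                             DeviceAllocation<Element> const& out_ref) {
  double max_diff = 0.0, mean_diff = 0.0;
  reference_abs_diff(out, out_ref, max_diff, mean_diff);   // absolute difference
  std::cout << "abs  max=" << max_diff << " mean=" << mean_diff << "\n";
  reference_rel_diff(out, out_ref, max_diff, mean_diff);   // relative difference
  std::cout << "rel  max=" << max_diff << " mean=" << mean_diff << "\n";
  return max_diff < 1e-2;  // set REF_PRINT_DIFF=1 to have the kernels print outliers
}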


@@ -0,0 +1,546 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale).
NVFP4 MMA has 2x throughput compared to MXFP8 MMA and 4x throughput compared to Ada Tensor Core FP8 MMA.
(See https://docs.nvidia.com/cuda/parallel-thread-execution).
This kernel leverages:
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Tensor Core MMA Instructions
4. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <iostream>
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79a_blackwell_geforce_nvfp4_bf16_gemm\n\n"
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,593 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture.
The kernel outputs quantized fp4 values with scale factors that will be the input of another GEMM.
This kernel is optimized for the GeForce RTX 50 series GPUs.
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Tensor Core MMA Instructions
4. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <iostream>
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
using ElementSFD = cutlass::float_ue8m0_t; // Element type for SFD matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
constexpr int InputSFVectorSize = 16;
constexpr int OutputSFVectorSize = InputSFVectorSize;
// D = alpha * acc + beta * C
// With BlockScaleFactor generation.
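// Roughly, this fusion first computes alpha * acc + beta * C in ElementCompute, then, for each
// OutputSFVectorSize (16) consecutive elements of a D row, derives a ue8m0 block scale factor
// (adjusted by the matrix-wide norm constant supplied at runtime) and stores the quantized NVFP4
// values in D alongside the scale factors in SFD. This is a descriptive sketch of the fusion's
// behavior, not its exact implementation.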
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
OutputSFVectorSize,
ElementD,
ElementCompute,
ElementSFD, LayoutSFDTag,
ElementC>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
FusionOperation
>::CollectiveOp;
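// The mainloop builder below carves its shared-memory stage count out of what remains after the
// epilogue's SharedStorage (StageCountAutoCarveout) and selects the TMA warp-specialized ping-pong
// schedule, in which, roughly, two consumer warp groups alternate between mainloop math and epilogue work.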
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong // Ping-pong kernel schedule policy.
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using FusionOp = typename Gemm::EpilogueOutputOp;
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutSFD layout_SFD;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating through the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
// Matrix-wide normalization constant
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
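// FP4/FP6 elements are narrower than a byte, so a raw pointer cannot address individual values.
// For such sub-byte types, cute::subbyte_iterator performs the packed bit-level addressing;
// wider types simply fall through to the plain pointer.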
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79b_blackwell_geforce_nvfp4_nvfp4_gemm\n\n"
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
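// Illustrative example: for m = n = k = 2048, flop = 2 * 2048^3 ~= 1.72e10,
// so a 1 ms average runtime corresponds to roughly 17,180 GFLOP/s.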
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
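/// The fill ranges below are heuristic choices for test data, not kernel requirements: narrow float
/// formats (FP4/FP6/FP8) get small ranges so products stay well representable, ue8m0 scale factors
/// are kept positive (they encode a power-of-two exponent), and wider types use a [-4, 4] range.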
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// For SFD tensor layout
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
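// Scale factor tensors hold, roughly, one scale factor per InputSFVectorSize (16) elements along K,
// arranged in the interleaved atom layout the block-scaled MMA expects. filter_zeros() below strips
// the broadcast (stride-0) modes so each allocation counts only the unique scale factor entries.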
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_Normconst.reset(cutlass::make_Coord(1));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_Normconst.at(cutlass::make_Coord(0)) = 2;
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
block_SFD.sync_device();
block_Normconst.sync_device();
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
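// When the fusion supports block scale factor output, also wire up the device pointers for the
// generated SFD tensor and for the matrix-wide normalization constant.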
if constexpr (IsBlockScaleSupported) {
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
}
return arguments;
}
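// Host-side verification: the reference cutlass::reference::host::Gemm3x consumes the same
// block-scaled A/B operands and regenerates D (and SFD) on the host; the device output is then
// compared bitwise, with non-zero norm checks guarding against trivially all-zero tensors.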
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
auto tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D), // TensorD
decltype(tensor_SFD), // TensorSfD
cute::Int<OutputSFVectorSize>,
cutlass::reference::host::SfStrategy::SfDGen
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and the GPU must have compute capability 120 (SM120).
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits; its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,546 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a mixed precision blockscaled GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale).
MXFP8 MMA has 2x throughput compared to Ada Tensor Core FP8 MMA.
(See https://docs.nvidia.com/cuda/parallel-thread-execution).
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
1. Warp-Specialized persistent kernel design that supports both the cooperative and ping-pong kernel schedules introduced in Hopper.
2. The new SW-controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Tensor Core MMA Instructions
4. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float6_t<cutlass::float_e3m2_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (128 FP6 elements = 96 bytes)
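// Note: mx_float8_t / mx_float6_t pair each data element type with ue8m0 (power-of-two) scale
// factors applied over 32-element blocks along K, following the OCP microscaling (MX) formats.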
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating through the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79c_blackwell_geforce_mixed_mxfp8_bf16_gemm\n\n"
<< " Blackwell mixed precision block scaled (MXFP8 x MXFP6) GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and the GPU must have compute capability 120 (SM120).
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits; its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,927 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped GEMM example using CUTLASS 3x APIs for the NVIDIA Blackwell SM120 architecture.
This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM120 TensorOp-based warp-specialized kernel
for narrow precisions (FP4) with input Scale Factors.
For this example all scheduling work is performed on the device, utilizing the device-side modification of TMA descriptors
to move between groups/problem_count (represented by groups).
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device
To run this example:
$ ./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
The above example command sizes all 10 groups at the given m, n, k values.
Skipping any of the problem dimensions randomizes it across the different groups.
The same applies to alpha and beta, which are randomized across the different groups.
To run this example for a set of problems using the benchmark option:
$ ./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm --benchmark=./test_benchmark.txt
Where the test_benchmark.txt may look like the following:
0 256x512x128
1 256x512x512
2 512x256x128
3 256x256x128
4 256x512x1024
5 1024x512x128 and so on
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <float.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "helper.h"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<ElementInput>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<ElementInput>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = float_e2m1_t; // Element type for D matrix operands
using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands
using ElementC = cutlass::half_t; // Element type for C matrix operands
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Alignment of D matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for internal computation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Epilogue Operator class tag
// Kernel Perf config
// Cluster Shape fixed to 1x1x1
using ThreadBlockShape = Shape<_128,_128,_128>;
using ClusterShape = Shape<_1,_1,_1>;
constexpr int OutputSFVectorSize = 16;
// D = alpha * acc + beta * C
// With BlockScaleFactor generation.
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
OutputSFVectorSize,
ElementD,
ElementCompute,
ElementSFD, LayoutCTag,
ElementC>;
// Cooperative kernel schedule
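// Note: the trailing '*' on the layout tags passed to the builders below selects the grouped
// (pointer-array) code path, where every group supplies its own stride/layout through device arrays.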
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag *, AlignmentC,
ElementD, LayoutCTag *, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto,
FusionOperation
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag *, AlignmentA,
ElementB, LayoutBTag *, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Auto schedule defaults to cooperative schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Pingpong kernel schedule
using CollectiveMainloopPingpong = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag *, AlignmentA,
ElementB, LayoutBTag *, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong
>::CollectiveOp;
using GemmKernelPingpong = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopPingpong,
CollectiveEpilogue
>;
using GemmPingpong = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelPingpong>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig<
OutputSFVectorSize,
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
>;
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
// Host-side allocations
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
using HostTensorA = cutlass::HostTensor<typename Gemm::ElementA, cutlass::layout::PackedVectorLayout>;
using HostTensorB = cutlass::HostTensor<typename Gemm::ElementB, cutlass::layout::PackedVectorLayout>;
using HostTensorSF = cutlass::HostTensor<typename Gemm::GemmKernel::CollectiveMainloop::ElementSF, cutlass::layout::PackedVectorLayout>;
using HostTensorC = cutlass::HostTensor<typename Gemm::ElementC, cutlass::layout::PackedVectorLayout>;
using HostTensorD = cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, cutlass::layout::PackedVectorLayout>;
std::vector<HostTensorA> block_A;
std::vector<HostTensorB> block_B;
std::vector<HostTensorSF> block_SFA;
std::vector<HostTensorSF> block_SFB;
std::vector<HostTensorC> block_C;
std::vector<HostTensorD> block_D;
std::vector<HostTensorSF> block_SFD;
std::vector<HostTensorD> block_ref_D;
std::vector<HostTensorSF> block_ref_SFD;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
cutlass::DeviceAllocation<const typename Gemm::GemmKernel::CollectiveMainloop::ElementSF *> ptr_SFA;
cutlass::DeviceAllocation<const typename Gemm::GemmKernel::CollectiveMainloop::ElementSF *> ptr_SFB;
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<typename Gemm::GemmKernel::CollectiveMainloop::ElementSF *> ptr_SFD;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
// Note, this is an array of pointers to alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
// A matrix-wide constant value to scale the output matrix.
// Avoids generating small FP4 values.
// NormConst is a single device-side constant value; it is not per-batch or per-group.
cutlass::DeviceAllocation<ElementAccumulator> norm_constant_device;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {
bool help = false;
bool verification = true;
bool use_pdl = false;
float alpha = std::numeric_limits<float>::max();
float beta = std::numeric_limits<float>::max();
float norm_constant = 1.0;
int iterations = 10;
int m = 1024, n = 2048, k = 512, groups = 10;
RasterOrderOptions raster_order = RasterOrderOptions::AlongN;
int max_sm_count = INT_MAX;
std::string benchmark_path;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
int const tma_alignment_bits = 128;
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementInput>::value;
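// With 128-bit TMA alignment and 4-bit FP4 inputs this evaluates to 32 elements, so every
// randomized or benchmark-provided extent below is kept a multiple of 32.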
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("no_verif")) {
verification = false;
}
if (cmd.check_cmd_line_flag("use_pdl")) {
use_pdl = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, std::numeric_limits<float>::max());
cmd.get_cmd_line_argument("beta", beta, std::numeric_limits<float>::max());
cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0));
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX);
// Decide how to initialize the problems
if (!benchmark_path.empty()) {
if (!benchmark_problems()) {
problem_sizes_host.clear();
return;
}
}
else {
randomize_problems(cmd);
}
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
if (raster_char == 'N' || raster_char == 'n') {
raster_order = RasterOrderOptions::AlongN;
}
else if (raster_char == 'M' || raster_char == 'm') {
raster_order = RasterOrderOptions::AlongM;
}
}
void randomize_problems(cutlass::CommandLine &cmd) {
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes_host.reserve(groups);
for (int i = groups; i > 0; i--) {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = alignment * ((rand() % 64) + 1);
}
if (n < 1) {
n = alignment * ((rand() % 64) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 64) + 1);
}
problem_sizes_host.push_back({m, n, k});
}
}
/// Load a benchmark
bool benchmark_problems() {
std::ifstream file(benchmark_path);
if (!file.good()) {
return false;
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
cutlass::gemm::GemmCoord extent;
std::vector<std::string> tokens;
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
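// e.g. with a 32-element alignment, an extent of 300 is rounded up to 320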
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
}
groups = static_cast<int>(problem_sizes_host.size());
return true;
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79d_blackwell_geforce_nvfp4_grouped_gemm\n\n"
<< " Blackwell Block Scaled Narrow Precision Grouped GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --norm_constant=<f32> Epilogue scalar normalization constant for the output matrix\n\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --benchmark=<str> Executes a benchmark problem size\n"
<< " --max_sm_count=<int> Run kernels using only this number of SMs\n"
<< " --no_verif Do not run (host-side) verification kernels\n"
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "79d_blackwell_geforce_nvfp4_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
{
// Number of real-valued multiply-adds
uint64_t fmas = uint64_t();
for (auto const & problem : problem_sizes_host) {
fmas += static_cast<uint64_t>(get<0>(problem)) *
static_cast<uint64_t>(get<1>(problem)) *
static_cast<uint64_t>(get<2>(problem));
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t error = cudaSuccess;
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Allocates device-side data
void allocate(const Options &options) {
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
stride_A_host.push_back(stride_A);
stride_B_host.push_back(stride_B);
layout_SFA_host.push_back(layout_SFA);
layout_SFB_host.push_back(layout_SFB);
stride_C_host.push_back(stride_C);
stride_D_host.push_back(stride_D);
block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A))));
block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B))));
block_SFA.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFA)))));
block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB)))));
block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C))));
block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD)))));
block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
block_ref_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD)))));
}
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
//
// Assign pointers
//
std::vector<typename Gemm::ElementA *> ptr_A_host(options.groups);
std::vector<typename Gemm::ElementB *> ptr_B_host(options.groups);
std::vector<typename Gemm::GemmKernel::CollectiveMainloop::ElementSF *> ptr_SFA_host(options.groups);
std::vector<typename Gemm::GemmKernel::CollectiveMainloop::ElementSF *> ptr_SFB_host(options.groups);
std::vector<typename Gemm::ElementC *> ptr_C_host(options.groups);
std::vector<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D_host(options.groups);
std::vector<typename Gemm::GemmKernel::CollectiveMainloop::ElementSF *> ptr_SFD_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
initialize_block(block_A.at(i).host_view(), seed + 2021);
initialize_block(block_B.at(i).host_view(), seed + 2022);
initialize_block(block_C.at(i).host_view(), seed + 2023);
initialize_block(block_SFA.at(i).host_view(), seed + 2024);
initialize_block(block_SFB.at(i).host_view(), seed + 2025);
block_A.at(i).sync_device();
block_B.at(i).sync_device();
block_C.at(i).sync_device();
block_SFA.at(i).sync_device();
block_SFB.at(i).sync_device();
ptr_A_host.at(i) = block_A.at(i).device_data();
ptr_B_host.at(i) = block_B.at(i).device_data();
ptr_SFA_host.at(i) = block_SFA.at(i).device_data();
ptr_SFB_host.at(i) = block_SFB.at(i).device_data();
ptr_C_host.at(i) = block_C.at(i).device_data();
ptr_D_host.at(i) = block_D.at(i).device_data();
ptr_SFD_host.at(i) = block_SFD.at(i).device_data();
alpha_host.push_back((options.alpha == std::numeric_limits<float>::max()) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == std::numeric_limits<float>::max()) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_SFA.reset(options.groups);
ptr_SFA.copy_from_host(ptr_SFA_host.data());
ptr_SFB.reset(options.groups);
ptr_SFB.copy_from_host(ptr_SFB_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
ptr_SFD.reset(options.groups);
ptr_SFD.copy_from_host(ptr_SFD_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
layout_SFA.reset(options.groups);
layout_SFA.copy_from_host(layout_SFA_host.data());
layout_SFB.reset(options.groups);
layout_SFB.copy_from_host(layout_SFB_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
norm_constant_device.reset(1);
norm_constant_device.copy_from_host(&options.norm_constant);
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Gemm>
typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count);
typename Gemm::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
// If alpha/beta are provided via the command line, a single scalar applies to all groups.
// Otherwise, per-group alpha/beta values are supplied through device pointer arrays.
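// For example, passing --alpha=2 --beta=0.707 takes the scalar path for every group, while omitting
// both flags lets initialize() draw a random alpha/beta per group and exercises the pointer-array path.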
if (options.alpha != std::numeric_limits<float>::max()){
// Single alpha for all groups
fusion_args.alpha = options.alpha;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.dAlpha = {_0{}, _0{}, 0};
}
else {
fusion_args.alpha = 0;
fusion_args.alpha_ptr_array = alpha_device.get();
// One alpha per group
fusion_args.dAlpha = {_0{}, _0{}, 1};
}
if (options.beta != std::numeric_limits<float>::max()) {
// Single beta for all groups
fusion_args.beta = options.beta;
fusion_args.beta_ptr_array = nullptr;
fusion_args.dBeta = {_0{}, _0{}, 0};
}
else {
fusion_args.beta = 0;
fusion_args.beta_ptr_array = beta_device.get();
// One beta per group
fusion_args.dBeta = {_0{}, _0{}, 1};
}
// Output Block SF
fusion_args.block_scale_factor_ptr = ptr_SFD.get(); // Enable for SF Output
fusion_args.norm_constant_ptr = norm_constant_device.get(); // Enable for SF Output
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = options.raster_order;
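// main() below calls run() with host_problem_shapes_available = false, so the arguments are
// built with a nullptr host copy and only the device-resident problem shapes are used.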
if (host_problem_shapes_available) {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info, scheduler
};
}
else {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info, scheduler
};
}
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
bool passed = true;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.at(i).host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
decltype(tensor_A),
decltype(tensor_SFA),
decltype(tensor_B),
decltype(tensor_SFB)
>
mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C);
auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D);
auto tensor_ref_SFD = cute::make_tensor(make_iterator(block_ref_SFD.at(i).host_data()), layout_SFD);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementCompute, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementCompute, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_ref_D), // TensorD
decltype(tensor_ref_SFD), // TensorSfD
cute::Int<OutputSFVectorSize>,
cutlass::reference::host::SfStrategy::SfDGen
> epilogue_params {alpha_host.at(i), beta_host.at(i), tensor_C, tensor_ref_D, tensor_ref_SFD, options.norm_constant};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.at(i).sync_host();
block_SFD.at(i).sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
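// The comparison is exact equality: both the quantized output D and its generated scale
// factors SFD must match the host reference produced by Gemm3x above.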
passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view());
passed &= cutlass::reference::host::TensorEquals(block_ref_SFD.at(i).host_view(), block_SFD.at(i).host_view());
// Check that the tensors have non-zero norms
passed &= (cutlass::reference::host::TensorNorm(block_ref_D.at(i).host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.at(i).host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_ref_SFD.at(i).host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_SFD.at(i).host_view()) > 0);
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options, bool host_problem_shapes_available = true)
{
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
if (options.verification) {
std::cout << " Host-side verification is now running - may be very slow for large cases." << std::endl;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
else {
std::cout << " Verfication is turned off for this run." << std::endl;
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
}
timer.stop();
// Compute the average per-iteration time (setup + run) and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
if (__CUDACC_VER_MAJOR__ < 12 ||
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120a).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
allocate(options);
initialize(options);
//
// Evaluate CUTLASS kernels
//
std::cout << "Running kernel with Cooperative kernel schedule:" << std::endl;
run<Gemm>(options, false /*host_problem_shapes_available*/);
std::cout << "Running kernel with Pingpong kernel schedule:" << std::endl;
run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,83 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=50 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=51 --iterations=0) # Fixed problem sizes
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) # Small problem sizes
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
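# Each TEST_* variable above is referenced under TEST_COMMAND_OPTIONS below, so each option
# set is intended to run as its own test invocation of the 79d grouped GEMM example.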
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
cutlass_example_add_executable(
79a_blackwell_geforce_nvfp4_bf16_gemm
79a_blackwell_geforce_nvfp4_bf16_gemm.cu
)
cutlass_example_add_executable(
79b_blackwell_geforce_nvfp4_nvfp4_gemm
79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu
)
cutlass_example_add_executable(
79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm
79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu
)
cutlass_example_add_executable(
79d_blackwell_geforce_nvfp4_grouped_gemm
79d_blackwell_geforce_nvfp4_grouped_gemm.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_PERF_LARGE_GROUP
)
endif()

View File

@ -0,0 +1,554 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions:
* mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.
Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution.
The kernel leverages:
1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Sparse Tensor Core MMA Instructions
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 16; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// E matrix configuration. Note, E is used to represent metadata tensor.
using ElementE = uint8_t; // Element type for E matrix operand
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120; // Kernel schedule policy
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelScheduleType // Mainloop schedule policy
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutE layout_E;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A_Decompressed;
cutlass::HostTensor<ElementE, cutlass::layout::PackedVectorLayout> block_E;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
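// make_iterator wraps sub-byte element types (e.g. 4-bit floats) in cute::subbyte_iterator,
// since a raw pointer cannot address individual sub-byte elements; full-byte types are
// returned as plain pointers.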
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "80a_blackwell_geforce_mxfp8_bf16_sparse_gemm\n\n"
<< " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize blocks related to sparse matrix A and its metadata E
bool initialize_sparse_blocks(const Options &options) {
auto workload = make_shape(options.m,
options.n,
options.k,
1);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
/// Alias SparseConfig and Compressor
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig,
cutlass::arch::Sm120>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
/// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs
CompressorUtility compressor_utility(workload, stride_A);
// Aligned M and K dimension sizes for A and E
int aligned_m_e = compressor_utility.get_metadata_m_physical();
int aligned_k_e = compressor_utility.get_metadata_k_physical();
int aligned_m_a = compressor_utility.get_tensorA_m_physical();
int aligned_k_a = compressor_utility.get_tensorA_k_physical();
/// Layout A and E
layout_A = SparseConfig::fill_layoutA(workload);
layout_E = SparseConfig::fill_layoutE(workload);
block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a));
block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e));
block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k));
initialize_block(block_A_Decompressed.host_view(), seed + 2020);
compressor_utility.structure_sparse_zero_mask_fill(
block_A_Decompressed.host_data(), static_cast<int>(seed + 2021));
block_A_Decompressed.sync_device();
/// Use compressor kernel to generate compressed Matrix A and E
cutlass::Status status { cutlass::Status::kSuccess };
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
{options.m, options.n, options.k, 1},
{block_A_Decompressed.device_data(),
stride_A,
block_A.device_data(),
block_E.device_data()},
{hw_info}
};
// Compress A and E
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.run();
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
return false;
}
block_A.sync_host();
block_E.sync_host();
return true;
}
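// At this point block_A holds the compressed A operand, block_E its sparsity metadata, and
// block_A_Decompressed the dense copy consumed by the host reference in verify().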
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(const Options &options) {
using namespace cute;
// Initialize the A, E (metadata), and decompressed-A blocks
if(!initialize_sparse_blocks(options)) return false;
// Define B, C and D blocks
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
// Define SFA and SFB tensors layouts
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
return true;
}
// Populates a Gemm::Arguments structure from the given commandline options
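// The mainloop receives the compressed A (block_A) with its sparse layout_A, the metadata
// tensor E with layout_E, and the SFA/SFB block scale factors; B, C, and D use the packed
// strides computed in initialize().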
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), layout_A,
block_B.device_data(), stride_B,
block_E.device_data(), layout_E,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
// Initialization
if(!initialize(options))
{
std::cerr << " Initialization failed! " << std::endl;
exit(-1);
}
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 120.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,578 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture.
This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions:
* mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.
Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution.
The kernel leverages:
1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Block Scaled Sparse Tensor Core MMA Instructions
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
Usage:
$ ./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::ColumnMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::ColumnMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
constexpr int outputVectorSize = 32; // Vector size for D matrix
using outputScaleFactor = cutlass::float_ue4m3_t; // Scale factor type for D matrix
// E matrix configuration. Note, E is used to represent metadata tensor.
using ElementE = uint8_t; // Element type for E matrix operand
// Kernel functional config
using ElementCompute = float; // Element type for computation inside mainloop and epilogue
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120; // Kernel schedule policy
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120, // Epilogue schedule policy
cutlass::epilogue::fusion::LinCombBlockScaleFactor< // Epilogue fusion to generate nvfp4 output
outputVectorSize, ElementD, ElementAccumulator, outputScaleFactor, LayoutDTag, ElementC>
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelScheduleType // Mainloop schedule policy
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<outputVectorSize>;
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutSFD layout_SFD;
LayoutE layout_E;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A_Decompressed;
cutlass::HostTensor<ElementE, cutlass::layout::PackedVectorLayout> block_E;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_SFD;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_reference_SFD;
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm\n\n"
<< " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize blocks related to sparse matrix A and its metadata E
bool initialize_sparse_blocks(const Options &options) {
auto workload = make_shape(options.m,
options.n,
options.k,
1);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
/// Alias SparseConfig and Compressor
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
cute::Shape<int, int, int, int>,
ElementA::DataType,
LayoutATag,
SparseConfig,
cutlass::arch::Sm120>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
/// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs
CompressorUtility compressor_utility(workload, stride_A);
// Aligned M and K dimension sizes for A and E
int aligned_m_e = compressor_utility.get_metadata_m_physical();
int aligned_k_e = compressor_utility.get_metadata_k_physical();
int aligned_m_a = compressor_utility.get_tensorA_m_physical();
int aligned_k_a = compressor_utility.get_tensorA_k_physical();
/// Layout A and E
layout_A = SparseConfig::fill_layoutA(workload);
layout_E = SparseConfig::fill_layoutE(workload);
block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a));
block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e));
block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k));
initialize_block(block_A_Decompressed.host_view(), seed + 2020);
compressor_utility.structure_sparse_zero_mask_fill(
block_A_Decompressed.host_data(), static_cast<int>(seed + 2021));
block_A_Decompressed.sync_device();
/// Use compressor kernel to generate compressed Matrix A and E
cutlass::Status status { cutlass::Status::kSuccess };
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
{options.m, options.n, options.k, 1},
{block_A_Decompressed.device_data(),
stride_A,
block_A.device_data(),
block_E.device_data()},
{hw_info}
};
// Compress A and E
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return false;
}
status = compressor_op.run();
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
return false;
}
block_A.sync_host();
block_E.sync_host();
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(const Options &options) {
using namespace cute;
// Initialize the compressed A, metadata E, and decompressed A blocks
if(!initialize_sparse_blocks(options)) return false;
// Define B, C and D blocks
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
// Define SFA and SFB tensors layouts
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_Normconst.reset(cutlass::make_Coord(1));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_Normconst.at(cutlass::make_Coord(0)) = 2;
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
block_SFD.sync_device();
block_Normconst.sync_device();
return true;
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), layout_A,
block_B.device_data(), stride_B,
block_E.device_data(), layout_E,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
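// Informational: D is itself a block-scaled output in this example, so the epilogue also produces
// scale factors for D. The two pointers set above supply the destination for those generated scale
// factors (block_SFD) and the normalization constant used while generating them, mirroring the
// SfDGen reference epilogue used in verify() below.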
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
auto tensor_SFD = cute::make_tensor(block_reference_SFD.host_data(), layout_SFD);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D), // TensorD
decltype(tensor_SFD), // TensorSfD
cute::Int<outputVectorSize>,
cutlass::reference::host::SfStrategy::SfDGen
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
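// Note: an element-wise equality check alone would also pass if both outputs were all zeros, so
// the norm checks below additionally require the reference and kernel results to be non-trivial.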
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
// Initialization
if(!initialize(options))
{
std::cerr << " Initialization failed! " << std::endl;
exit(-1);
}
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with a CUDA 12.8 or newer Toolkit to run this example,
// and the GPU must have compute capability at least 120 (sm120a).
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Return zero so this test passes on older Toolkits; in that case the example is a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,41 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
cutlass_example_add_executable(
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu
)
cutlass_example_add_executable(
80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm
80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu
)
endif()
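# Informational note (assuming a standard CUTLASS CMake setup): the two executables above are only
# configured when the 120a architecture is listed in CUTLASS_NVCC_ARCHS, e.g.
#   cmake .. -DCUTLASS_NVCC_ARCHS="120a"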

View File

@ -0,0 +1,581 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C and D matrices in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape_MNK = Shape<_128,_128,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;
using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
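// Informational note: sm100_trivial_blockwise_scale_config appears to pick a scale-factor
// granularity equal to the MMA tile itself (128x128x128 here), i.e. one scale per tile of A and
// one per tile of B along each mode; the groupwise example later in this set chooses the
// granularities explicitly through Sm100BlockwiseScaleConfig instead.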
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutC, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelScheduleSm100Blockwise
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
// Strides just iterate over scalars and have no zeros
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
// Layouts are tiled to the problem size and the strides have zeros
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "81_blackwell_gemm_blockwise\n\n"
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "112_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
scope_min = -8;
scope_max = 8;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
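// Note: the scale-factor layouts broadcast each scale across the block of elements it covers, so
// their strides contain zeros; size(filter_zeros(...)) is intended to count only the distinct
// scale values that need storage, which is why it sizes the SFA/SFB allocations here.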
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
tensor_SFA.resize(blockscale_a_coord);
tensor_SFB.resize(blockscale_b_coord);
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options) {
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A,
tensor_B.device_data(), stride_B,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options) {
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with a CUDA 12.0 or newer Toolkit to run this example,
// and the GPU must have compute capability 100 (sm100a).
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Return zero so this test passes on older Toolkits; in that case the example is a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100 (sm100a)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,587 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C and D matrices in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape_MNK = Shape<_256,_128,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_2,_1,_1>;
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
// Note when we have multiple scale factors per tile (in this case 128 scales in M per tile), we will restrict up to a
// 16B alignment if possible (i.e., we have at least 16B of scales in M).
// In this case the smallest M that can be executed is 16. To avoid this for smaller M, you can swap A and B
// and transpose A, B, C, and scales since B^T A^T = C^T.
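// Hedged illustration of the note above: with ScaleGranularityM = 1 the smallest M that can run is
// 16, so a hypothetical problem such as M=8, N=4096, K=4096 could instead be launched as
// D^T = B^T * A^T: pass B as the "A" operand and A as the "B" operand, exchange M and N in the
// problem shape, flip the row-/column-major layouts, and swap the scale-factor layouts to match,
// since B^T A^T = (A B)^T.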
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutC, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelScheduleSm100Blockwise
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
// Strides just iterate over scalars and have no zeros
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
// Layouts are tiled to the problem size and the strides have zeros
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "81_blackwell_gemm_groupwise\n\n"
<< " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "81_blackwell_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
scope_min = -8;
scope_max = 8;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
tensor_SFA.resize(blockscale_a_coord);
tensor_SFB.resize(blockscale_b_coord);
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options) {
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A,
tensor_B.device_data(), stride_B,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options) {
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with a CUDA 12.0 or newer Toolkit to run this example,
// and the GPU must have compute capability 100 (sm100a).
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Return zero so this test passes on older Toolkits; in that case the example is a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100 (sm100a)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,754 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP8 blockwise-scaled grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
In this example M, N, and K are fixed across groups.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C and D matrices in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float;
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape_MNK = Shape<_128,_128,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;
// Shape of the tile computed by each SM
using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC *, AlignmentC,
ElementD, LayoutC *, AlignmentD,
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA, LayoutSFA>);
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB, LayoutSFB>);
/// Initialization
uint64_t seed;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_SFA;
std::vector<int64_t> offset_SFB;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<ElementD *> ptr_ref_D_host;
std::vector<ElementA *> ptr_A_host;
std::vector<ElementB *> ptr_B_host;
std::vector<ElementC *> ptr_C_host;
std::vector<ElementD *> ptr_D_host;
std::vector<ElementAccumulator *> ptr_SFA_host;
std::vector<ElementAccumulator *> ptr_SFB_host;
// Shared Allocations
cutlass::HostTensor<ElementA, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementB, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_ref_D;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> block_SFB;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<const ElementAccumulator *> ptr_SFA;
cutlass::DeviceAllocation<const ElementAccumulator *> ptr_SFB;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 2048, k = 512, groups = 10;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
for (int i = 0; i < groups; ++i) {
problem_sizes_host.push_back({m, n, k});
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "81_blackwell_grouped_gemm_blockwise\n\n"
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "81_blackwell_grouped_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * groups;
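// For example, with the default sizes (m=1024, n=2048, k=512, groups=10) this is
// 2 * 1024 * 2048 * 512 * 10 ≈ 2.15e10 flop, i.e. roughly 21.5 GFLOP per launch.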
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
scope_min = -1;
scope_max = 1;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options const& options) {
int32_t total_elements_A = 0;
int32_t total_elements_B = 0;
int32_t total_elements_C = 0;
int32_t total_elements_D = 0;
int32_t total_elements_SFA = 0;
int32_t total_elements_SFB = 0;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_SFA.push_back(total_elements_SFA);
offset_SFB.push_back(total_elements_SFB);
int32_t elements_A = M * K;
int32_t elements_B = K * N;
int32_t elements_C = M * N;
int32_t elements_D = M * N;
auto gemm_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
auto gemm_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
int32_t elements_SFA = cosize(gemm_layout_SFA);
int32_t elements_SFB = cosize(gemm_layout_SFB);
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_SFA += elements_SFA;
total_elements_SFB += elements_SFB;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
layout_SFA_host.push_back(gemm_layout_SFA);
layout_SFB_host.push_back(gemm_layout_SFB);
}
block_A.resize(cutlass::make_Coord(total_elements_A));
block_B.resize(cutlass::make_Coord(total_elements_B));
block_C.resize(cutlass::make_Coord(total_elements_C));
block_D.resize(cutlass::make_Coord(total_elements_D));
block_ref_D.resize(cutlass::make_Coord(total_elements_D));
block_SFA.resize(cutlass::make_Coord(total_elements_SFA));
block_SFB.resize(cutlass::make_Coord(total_elements_SFB));
initialize_tensor(block_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(block_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(block_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_scale_tensor(block_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2026);
initialize_scale_tensor(block_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2027);
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
// copy problem sizes
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
std::vector<ElementA *> device_ptr_A_host(options.groups);
std::vector<ElementB *> device_ptr_B_host(options.groups);
std::vector<ElementC *> device_ptr_C_host(options.groups);
std::vector<ElementD *> device_ptr_D_host(options.groups);
std::vector<ElementAccumulator *> device_ptr_SFA_host(options.groups);
std::vector<ElementAccumulator *> device_ptr_SFB_host(options.groups);
ptr_A_host = std::vector<ElementA *>(options.groups);
ptr_B_host = std::vector<ElementB *>(options.groups);
ptr_C_host = std::vector<ElementC *>(options.groups);
ptr_D_host = std::vector<ElementD *>(options.groups);
ptr_SFA_host = std::vector<ElementAccumulator *>(options.groups);
ptr_SFB_host = std::vector<ElementAccumulator *>(options.groups);
ptr_ref_D_host = std::vector<ElementD *>(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
// Ptrs for A
ptr_A_host.at(i) = block_A.host_data() + offset_A.at(i);
device_ptr_A_host.at(i) = block_A.device_data() + offset_A.at(i);
// Ptrs for B
ptr_B_host.at(i) = block_B.host_data() + offset_B.at(i);
device_ptr_B_host.at(i) = block_B.device_data() + offset_B.at(i);
// Ptrs for C
ptr_C_host.at(i) = block_C.host_data() + offset_C.at(i);
device_ptr_C_host.at(i) = block_C.device_data() + offset_C.at(i);
// Ptrs for D
ptr_D_host.at(i) = block_D.host_data() + offset_D.at(i);
device_ptr_D_host.at(i) = block_D.device_data() + offset_D.at(i);
ptr_ref_D_host.at(i) = block_ref_D.host_data() + offset_D.at(i);
// Ptrs for SFA
ptr_SFA_host.at(i) = block_SFA.host_data() + offset_SFA.at(i);
device_ptr_SFA_host.at(i) = block_SFA.device_data() + offset_SFA.at(i);
// Ptrs for SFB
ptr_SFB_host.at(i) = block_SFB.host_data() + offset_SFB.at(i);
device_ptr_SFB_host.at(i) = block_SFB.device_data() + offset_SFB.at(i);
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(device_ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(device_ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(device_ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(device_ptr_D_host.data());
ptr_SFA.reset(options.groups);
ptr_SFA.copy_from_host(device_ptr_SFA_host.data());
ptr_SFB.reset(options.groups);
ptr_SFB.copy_from_host(device_ptr_SFB_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
layout_SFA.reset(options.groups);
layout_SFA.copy_from_host(layout_SFA_host.data());
layout_SFB.reset(options.groups);
layout_SFB.copy_from_host(layout_SFB_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(),
ptr_B.get(), stride_B.get(),
ptr_SFA.get(), layout_SFA.get(),
ptr_SFB.get(), layout_SFB.get()
},
{
{}, // epilogue.thread
ptr_C.get(), stride_C.get(),
ptr_D.get(), stride_D.get()
},
hw_info
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
block_D.sync_host();
for (int i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(ptr_A_host.at(i),
cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i)));
auto B = cute::make_tensor(ptr_B_host.at(i),
cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i)));
auto C = cute::make_tensor(ptr_C_host.at(i),
cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i)));
auto D = cute::make_tensor(ptr_ref_D_host.at(i),
cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i)));
auto SFA = cute::make_tensor(ptr_SFA_host.at(i), layout_SFA_host.at(i));
auto SFB = cute::make_tensor(ptr_SFB_host.at(i), layout_SFB_host.at(i));
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
}
bool passed = cutlass::reference::host::TensorEquals(block_ref_D.host_view(), block_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options) {
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.groups << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.0 Toolkit to run this example,
// and the GPU must have a compute capability of at least sm100a.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Return zero so this test passes on older Toolkits, where its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,761 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP8 groupwise-scaled grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
In this example M, N, and K are fixed across groups.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float;
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by the tcgen05 MMA; it can span 2 SMs when the M-mode of the cluster shape is even
using MmaTileShape_MNK = Shape<_256,_128,_128>;
// Shape of the threadblocks in a cluster: <1,1,1> for cta_group = 1, <2,1,1> for cta_group = 2
using ClusterShape_MNK = Shape<_2,_1,_1>;
// Groupwise scale-factor granularities along M, N, and K
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
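// With these granularities there is presumably one scale factor per row of A for every
// 128-element block of K, and one scale factor per 128x128 block of B (an interpretation
// of the template parameters above, stated here for illustration only).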
// Note when we have multiple scale factors per tile (in this case 128 scales in M per tile), we will restrict up to a
// 16B alignment if possible (i.e., we have at least 16B of scales in M).
// In this case the smallest M that can be executed is 16. To avoid this for smaller M, you can swap A and B
// and transpose A, B, C, and scales since B^T A^T = C^T.
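// Illustrative sketch of that swap (an assumption, not used in this example): since
// (A * B)^T = B^T * A^T, a problem with a very small M can instead be launched with the roles
// of A and B exchanged and A, B, C, D, and the scale factors interpreted as their transposes,
// so the small extent becomes N and the 16B scale-alignment restriction on M no longer applies.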
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC *, AlignmentC,
ElementD, LayoutC *, AlignmentD,
cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA, LayoutSFA>);
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB, LayoutSFB>);
/// Initialization
uint64_t seed;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_SFA;
std::vector<int64_t> offset_SFB;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<ElementD *> ptr_ref_D_host;
std::vector<ElementA *> ptr_A_host;
std::vector<ElementB *> ptr_B_host;
std::vector<ElementC *> ptr_C_host;
std::vector<ElementD *> ptr_D_host;
std::vector<ElementAccumulator *> ptr_SFA_host;
std::vector<ElementAccumulator *> ptr_SFB_host;
// Shared Allocations
cutlass::HostTensor<ElementA, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementB, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_ref_D;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> block_SFB;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<const ElementAccumulator *> ptr_SFA;
cutlass::DeviceAllocation<const ElementAccumulator *> ptr_SFB;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 2048, k = 512, groups = 10;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
for (int i = 0; i < groups; ++i) {
problem_sizes_host.push_back({m, n, k});
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "81_blackwell_grouped_gemm_groupwise\n\n"
<< " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "81_blackwell_grouped_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * groups;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
scope_min = -1;
scope_max = 1;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
int32_t total_elements_A = 0;
int32_t total_elements_B = 0;
int32_t total_elements_C = 0;
int32_t total_elements_D = 0;
int32_t total_elements_SFA = 0;
int32_t total_elements_SFB = 0;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_SFA.push_back(total_elements_SFA);
offset_SFB.push_back(total_elements_SFB);
int32_t elements_A = M * K;
int32_t elements_B = K * N;
int32_t elements_C = M * N;
int32_t elements_D = M * N;
auto gemm_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
auto gemm_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
int32_t elements_SFA = cosize(gemm_layout_SFA);
int32_t elements_SFB = cosize(gemm_layout_SFB);
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_SFA += elements_SFA;
total_elements_SFB += elements_SFB;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
layout_SFA_host.push_back(gemm_layout_SFA);
layout_SFB_host.push_back(gemm_layout_SFB);
}
block_A.resize(cutlass::make_Coord(total_elements_A));
block_B.resize(cutlass::make_Coord(total_elements_B));
block_C.resize(cutlass::make_Coord(total_elements_C));
block_D.resize(cutlass::make_Coord(total_elements_D));
block_ref_D.resize(cutlass::make_Coord(total_elements_D));
block_SFA.resize(cutlass::make_Coord(total_elements_SFA));
block_SFB.resize(cutlass::make_Coord(total_elements_SFB));
initialize_tensor(block_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(block_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(block_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_scale_tensor(block_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2026);
initialize_scale_tensor(block_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2027);
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
// copy problem sizes
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
std::vector<ElementA *> device_ptr_A_host(options.groups);
std::vector<ElementB *> device_ptr_B_host(options.groups);
std::vector<ElementC *> device_ptr_C_host(options.groups);
std::vector<ElementD *> device_ptr_D_host(options.groups);
std::vector<ElementAccumulator *> device_ptr_SFA_host(options.groups);
std::vector<ElementAccumulator *> device_ptr_SFB_host(options.groups);
ptr_A_host = std::vector<ElementA *>(options.groups);
ptr_B_host = std::vector<ElementB *>(options.groups);
ptr_C_host = std::vector<ElementC *>(options.groups);
ptr_D_host = std::vector<ElementD *>(options.groups);
ptr_SFA_host = std::vector<ElementAccumulator *>(options.groups);
ptr_SFB_host = std::vector<ElementAccumulator *>(options.groups);
ptr_ref_D_host = std::vector<ElementD *>(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
// Ptrs for A
ptr_A_host.at(i) = block_A.host_data() + offset_A.at(i);
device_ptr_A_host.at(i) = block_A.device_data() + offset_A.at(i);
// Ptrs for B
ptr_B_host.at(i) = block_B.host_data() + offset_B.at(i);
device_ptr_B_host.at(i) = block_B.device_data() + offset_B.at(i);
// Ptrs for C
ptr_C_host.at(i) = block_C.host_data() + offset_C.at(i);
device_ptr_C_host.at(i) = block_C.device_data() + offset_C.at(i);
// Ptrs for D
ptr_D_host.at(i) = block_D.host_data() + offset_D.at(i);
device_ptr_D_host.at(i) = block_D.device_data() + offset_D.at(i);
ptr_ref_D_host.at(i) = block_ref_D.host_data() + offset_D.at(i);
// Ptrs for SFA
ptr_SFA_host.at(i) = block_SFA.host_data() + offset_SFA.at(i);
device_ptr_SFA_host.at(i) = block_SFA.device_data() + offset_SFA.at(i);
// Ptrs for SFB
ptr_SFB_host.at(i) = block_SFB.host_data() + offset_SFB.at(i);
device_ptr_SFB_host.at(i) = block_SFB.device_data() + offset_SFB.at(i);
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(device_ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(device_ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(device_ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(device_ptr_D_host.data());
ptr_SFA.reset(options.groups);
ptr_SFA.copy_from_host(device_ptr_SFA_host.data());
ptr_SFB.reset(options.groups);
ptr_SFB.copy_from_host(device_ptr_SFB_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
layout_SFA.reset(options.groups);
layout_SFA.copy_from_host(layout_SFA_host.data());
layout_SFB.reset(options.groups);
layout_SFB.copy_from_host(layout_SFB_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(),
ptr_B.get(), stride_B.get(),
ptr_SFA.get(), layout_SFA.get(),
ptr_SFB.get(), layout_SFB.get()
},
{
{}, // epilogue.thread
ptr_C.get(), stride_C.get(),
ptr_D.get(), stride_D.get()
},
hw_info
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
block_D.sync_host();
for (int i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(ptr_A_host.at(i),
cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i)));
auto B = cute::make_tensor(ptr_B_host.at(i),
cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i)));
auto C = cute::make_tensor(ptr_C_host.at(i),
cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i)));
auto D = cute::make_tensor(ptr_ref_D_host.at(i),
cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i)));
auto SFA = cute::make_tensor(ptr_SFA_host.at(i), layout_SFA_host.at(i));
auto SFB = cute::make_tensor(ptr_SFB_host.at(i), layout_SFB_host.at(i));
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
}
bool passed = cutlass::reference::host::TensorEquals(block_ref_D.host_view(), block_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options) {
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.groups << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.0 Toolkit to run this example,
// and the GPU must have a compute capability of at least sm100a.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Return zero so this test passes on older Toolkits, where its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,71 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Non-default epilogue scalars
set(TEST_SMALL --m=256 --n=128 --k=128 --iterations=0) # Small problem sizes
cutlass_example_add_executable(
81_blackwell_gemm_blockwise
81_blackwell_gemm_blockwise.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_EPILOGUE
TEST_SMALL
)
cutlass_example_add_executable(
81_blackwell_gemm_groupwise
81_blackwell_gemm_groupwise.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_EPILOGUE
TEST_SMALL
)
cutlass_example_add_executable(
81_blackwell_grouped_gemm_blockwise
81_blackwell_grouped_gemm_blockwise.cu
TEST_COMMAND_OPTIONS
TEST_SMALL
)
cutlass_example_add_executable(
81_blackwell_grouped_gemm_groupwise
81_blackwell_grouped_gemm_groupwise.cu
TEST_COMMAND_OPTIONS
TEST_SMALL
)
endif()


@@ -0,0 +1,873 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Distributed GEMM (DistGEMM) for Blackwell.
This example runs Tensor Parallel GEMMs using the (experimental) Distributed GEMM API in
CUTLASS. For more information, please refer to README.md.
Note that Distributed GEMM assumes an any-to-any NVLink network topology.
To check whether your device is compatible, run:
$ nvidia-smi topo -m
and make sure there's an any-to-any NVLink topology. It would look like this:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X
You should also check whether the driver enables peer-to-peer access:
$ nvidia-smi topo -p2p r
Output should be something like this:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X OK OK OK OK OK OK OK
GPU1 OK X OK OK OK OK OK OK
GPU2 OK OK X OK OK OK OK OK
GPU3 OK OK OK X OK OK OK OK
GPU4 OK OK OK OK X OK OK OK
GPU5 OK OK OK OK OK X OK OK
GPU6 OK OK OK OK OK OK X OK
GPU7 OK OK OK OK OK OK OK X
It is recommended to build this target with the following flag to enable
Grid Dependency Control instructions (GDC) in CUTLASS:
- CUTLASS_ENABLE_GDC_FOR_SM100
Example:
$ mkdir build && cd build
$ cmake .. -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
$ cd examples/82_blackwell_distributed_gemm
$ make
$ ./82_blackwell_distributed_gemm
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
// Distributed GEMM headers
#include "cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp"
#include "cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp"
#include "cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp"
#include "helper.h"
// Distributed GEMM helpers
#include "dist_gemm_helpers.h"
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Distributed GEMM configuration
/////////////////////////////////////////////////////////////////////////////////////////////////
// TP size (= number of processors/GPUs)
using TP = _8;
static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
// Distributed GEMM tiling/sharding schedule
// Choices:
//
// * All Gather + GEMM:
// * AllGather1D_TilingCD_RotatingA
// * AllGather1D_TilingCD_RotatingB
//
// * GEMM + Reduce Scatter:
// * ReduceScatter1D_TilingA_RotatingC
// * ReduceScatter1D_TilingB_RotatingC
using DistSchedule = cutlass::distributed::schedules::AllGather1D_TilingCD_RotatingA<TP>;
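// For example, to pair the GEMM with a reduce-scatter instead, one of the schedules listed
// above could presumably be selected like this (illustrative, not used in this example):
//   using DistSchedule = cutlass::distributed::schedules::ReduceScatter1D_TilingA_RotatingC<TP>;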
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
using ElementD = cutlass::float_e4m3_t; // Element type for D matrix operand
using LayoutD = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
// MMA and Cluster Tile Shapes
// Shape of the tile computed by the tcgen05 MMA; it can span 2 SMs when the M-mode of the cluster shape is even
using MmaTileShape_MNK = Shape<_256,_256,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_2,_1,_1>;
// Shape of the tile computed by each SM
using PerSmTileShape_MNK = Shape<_128, _256, _128>;
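// With the 2-SM tcgen05 MMA above, the 256x256x128 MMA tile is presumably split along M across
// the two CTAs in the cluster, so each SM computes a 128x256x128 tile (hence PerSmTileShape_MNK).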
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
PerSmTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
// Build the mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecialized2SmSm100
>::CollectiveOp;
// Compose into a kernel
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
// We're going to use the single-device GEMM as reference
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Instantiate Distributed GEMM kernel
using DistGemmKernel = cutlass::distributed::kernel::DistributedGemmKernelWrapper<
GemmKernel,
DistSchedule
>;
using DistGemm = cutlass::distributed::device::DistributedGemmUniversalAdapter<DistGemmKernel>;
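// Note: DistributedGemmKernelWrapper composes the single-device GemmUniversal kernel with
// the DistSchedule above; each of the TP devices runs one DistGemm instance over its local
// shard (see dist_gemm_args_from_options() below), while Gemm serves as the single-device
// reference.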
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
using HostTensorA = typename cutlass::HostTensor<ElementA, LayoutA>;
using HostTensorB = typename cutlass::HostTensor<ElementB, LayoutB>;
using HostTensorC = typename cutlass::HostTensor<ElementC, LayoutC>;
using HostTensorD = typename cutlass::HostTensor<ElementD, LayoutD>;
// Reference GEMM tensors
HostTensorA tensor_A;
HostTensorB tensor_B;
HostTensorC tensor_C;
HostTensorD tensor_D;
HostTensorD tensor_ref_D;
// DistGEMM tensors (multi-device)
HostTensorA tensor_A_arr[TP_];
HostTensorB tensor_B_arr[TP_];
HostTensorC tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.f, beta = 0.f;
int iterations = 100;
int warmup_iterations = 10;
int m = 16384, n = 106496, k = 16384, l = 1;
float eps = 0.f;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("warmup-iterations", warmup_iterations);
cmd.get_cmd_line_argument("eps", eps);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "82_blackwell_distributed_gemm\n\n"
<< " Blackwell Distributed GEMM (DistGEMM). \n"
<< " For more details please refer to the source file.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch) of the GEMM (default: 1)\n"
<< " --alpha=<f32> Epilogue scalar alpha (default: 1.0)\n"
<< " --beta=<f32> Epilogue scalar beta (default: 0.0)\n"
<< " --iterations=<int> Number of profiling iterations to perform (default: 100)\n"
<< " --warmup-iterations=<int> Number of warmup iterations prior to profiling (default: 10)\n"
<< " --eps=<f32> Threshold for error compared to reference "
<< "GEMM (default: 0.0)\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "82_blackwell_distributed_gemm" << " --m=16384 --n=106496 --k=16384 \n\n";
return out;
}
/// Compute performance in TFLOP/s
double tflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l / TP_;
double tflop = double(flop) / double(1.0e12);
return tflop / runtime_s;
}
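// For the default problem size (m=16384, n=106496, k=16384, l=1) the flop count above is
// roughly 2 * 16384 * 106496 * 16384 ~= 5.7e13 FLOP in total, i.e. about 57.2 TFLOP, of
// which each device performs a 1/TP share.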
};
/// Result structure
struct Result {
double avg_runtime_ms;
double tflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double tflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), tflops(tflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed,
bool is_device_tensor = false) {
double scope_max, scope_min;
int bits = cutlass::sizeof_bits<Element>::value;
if (bits == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits <= 16) {
scope_max = 2;
scope_min = -2;
}
else {
scope_max = 8;
scope_min = -8;
}
if (is_device_tensor) {
using Real = typename cutlass::RealType<Element>::Type;
cutlass::reference::device::TensorFillRandomUniform(
view, seed, static_cast<Real>(scope_max), static_cast<Real>(scope_min), 0);
cudaDeviceSynchronize();
} else {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
// Setup (reference) GEMM tensors
auto shape_A = cute::select<0,2,3>(problem_shape);
auto shape_B = cute::select<1,2,3>(problem_shape);
auto shape_C = cute::select<0,1,3>(problem_shape);
auto shape_D = cute::select<0,1,3>(problem_shape);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, shape_A);
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C);
stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D);
auto a_coord = cutlass::make_Coord(size(shape_A), 1);
auto b_coord = cutlass::make_Coord(size(shape_B), 1);
auto c_coord = cutlass::make_Coord(size(shape_C), 1);
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
initialize_tensor(tensor_A.device_view(), seed + 2022, /* is_device_tensor = */ true);
initialize_tensor(tensor_B.device_view(), seed + 2023, /* is_device_tensor = */ true);
initialize_tensor(tensor_C.device_view(), seed + 2024, /* is_device_tensor = */ true);
tensor_A.sync_host();
tensor_B.sync_host();
tensor_C.sync_host();
tensor_D.sync_host();
tensor_ref_D.sync_host();
// Set up DistGEMM tensors
auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape);
auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape);
auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape);
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
auto a_coord_device = cutlass::make_Coord(size(local_shape_A), 1);
auto b_coord_device = cutlass::make_Coord(size(local_shape_B), 1);
auto c_coord_device = cutlass::make_Coord(size(local_shape_C), 1);
int primary_device_idx;
CUDA_CHECK(cudaGetDevice(&primary_device_idx));
// Enable any-to-any access
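// Note: the rotating schedules assume every device can read/write its peers' buffers
// directly (e.g. over NVLink), so the example requires full peer-to-peer access between
// all TP devices and aborts if any pair cannot be enabled.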
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
int can_access;
CUDA_CHECK(cudaSetDevice(device_idx));
for (int peer_idx = 0; peer_idx < TP_; ++peer_idx) {
if (peer_idx != device_idx) {
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, device_idx, peer_idx));
if (not can_access) {
std::cerr << "FAILURE: Device " << device_idx << " can't access device " << peer_idx << "." <<
std::endl;
exit(EXIT_FAILURE);
}
CUDA_CHECK(cudaDeviceEnablePeerAccess(peer_idx, 0));
}
}
tensor_A_arr[device_idx].resize(a_coord_device);
tensor_B_arr[device_idx].resize(b_coord_device);
tensor_C_arr[device_idx].resize(c_coord_device);
tensor_D_arr[device_idx].resize(c_coord_device);
}
CUDA_CHECK(cudaSetDevice(primary_device_idx));
}
/// Commandline options -> Gemm/DistGemm Arguments
using GemmArguments = typename Gemm::Arguments;
GemmArguments gemm_args_from_options(const Options &options) {
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
{
{static_cast<ElementCompute>(options.alpha), static_cast<ElementCompute>(options.beta)},
tensor_C.device_data(), stride_C,
tensor_ref_D.device_data(), stride_D
}
};
return arguments;
}
using DistGemmArguments = typename DistGemm::Arguments;
DistGemmArguments dist_gemm_args_from_options(
const Options &options,
int device_idx,
cudaStream_t stream) {
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
auto global_A = cute::make_tensor(tensor_A.device_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto global_B = cute::make_tensor(tensor_B.device_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto global_C = cute::make_tensor(tensor_C.device_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto global_A_device_slice = DistSchedule::get_device_slice_A(global_A, device_idx);
auto global_B_device_slice = DistSchedule::get_device_slice_B(global_B, device_idx);
auto global_C_device_slice = DistSchedule::get_device_slice_C(global_C, device_idx);
auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape);
auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape);
auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape);
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
auto local_stride_A = cutlass::make_cute_packed_stride(StrideA{}, local_shape_A);
auto local_stride_B = cutlass::make_cute_packed_stride(StrideB{}, local_shape_B);
auto local_stride_C = cutlass::make_cute_packed_stride(StrideC{}, local_shape_C);
auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D);
auto local_A = cute::make_tensor(
tensor_A_arr[device_idx].device_data(),
make_layout(local_shape_A, local_stride_A));
auto local_B = cute::make_tensor(
tensor_B_arr[device_idx].device_data(),
make_layout(local_shape_B, local_stride_B));
auto local_C = cute::make_tensor(
tensor_C_arr[device_idx].device_data(),
make_layout(local_shape_C, local_stride_C));
auto local_D = cute::make_tensor(
tensor_D_arr[device_idx].device_data(),
make_layout(local_shape_D, local_stride_D));
// Copy over tensor tiles for the first iteration
cutlass::device_copy(global_A_device_slice, local_A, stream);
cutlass::device_copy(global_B_device_slice, local_B, stream);
cutlass::device_copy(global_C_device_slice, local_C, stream);
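// Note: only the first iteration's tiles are staged here from the full-sized tensors on the
// primary device; later stages are expected to be exchanged device-to-device by the schedule
// itself as the operands rotate across devices.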
DistGemmArguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm, // mode
problem_shape, // problem shape
{
reinterpret_cast<const ElementA*>(local_A.data()),
local_A.stride(),
reinterpret_cast<const ElementB*>(local_B.data()),
local_B.stride()
}, // mainloop
{
{ // epilogue.thread
static_cast<ElementCompute>(options.alpha),
static_cast<ElementCompute>(options.beta)
},
reinterpret_cast<const ElementC*>(local_C.data()),
local_C.stride(),
reinterpret_cast<ElementD*>(local_D.data()),
local_D.stride(),
}, // epilogue
{}, // hw_info
{} // scheduler
};
return arguments;
}
// Gathers results from each device's local D tensor back into the full-sized D tensor on the primary device.
void gather_results(const Options &options, int device_idx, cudaStream_t stream = nullptr) {
auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l);
// Global dest
auto global_D = cute::make_tensor(tensor_D.device_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto global_D_device_slice = DistSchedule::get_device_slice_D(global_D, device_idx);
// Device_idx local dest
auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape);
auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D);
auto local_D = cute::make_tensor(
tensor_D_arr[device_idx].device_data(),
make_layout(local_shape_D, local_stride_D)
);
// Copy to global dest
cutlass::device_copy(local_D, global_D_device_slice, stream);
}
bool verify(const Options &options) {
tensor_D.sync_host();
tensor_ref_D.sync_host();
bool passed = false;
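// With eps == 0 (the default) the check is exact element-wise equality against the
// reference; otherwise a relative error metric is compared against the user-provided
// threshold.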
if (options.eps == 0.f) {
passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
} else {
double err = cutlass::reference::host::TensorRelativeErrorMetric(
tensor_D.host_view(),
tensor_ref_D.host_view());
passed = err < options.eps;
}
if (options.m <= 64 && options.n <= 64) {
std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n";
std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n";
}
return passed;
}
/// Execute a given example GEMM computation
int run(Options &options) {
int primary_device_idx;
cudaError_t device_get_result = cudaGetDevice(&primary_device_idx);
if (device_get_result != cudaSuccess) {
throw std::runtime_error("cudaGetDevice() failed");
}
initialize(options);
// Reference single-GPU GEMM
Gemm reference_gemm;
cutlass::device_memory::allocation<uint8_t> reference_workspace;
auto reference_arguments = gemm_args_from_options(options);
size_t reference_workspace_size = Gemm::get_workspace_size(reference_arguments);
reference_workspace = cutlass::device_memory::allocation<uint8_t>(reference_workspace_size);
CUTLASS_CHECK(reference_gemm.can_implement(reference_arguments));
CUTLASS_CHECK(reference_gemm.initialize(reference_arguments, reference_workspace.get()));
CUTLASS_CHECK(reference_gemm.run());
using ElementBarrier = typename DistGemm::ElementBarrier;
using ElementFlag = typename DistGemmKernel::ElementFlag;
// Set up per-device streams
cudaStream_t stream_arr[TP_];
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
// Create stream
CUDA_CHECK(cudaStreamCreate(&stream_arr[device_idx]));
}
// Instantiate DistGEMM
DistGemm dist_gemm_arr[TP_]; // Distributed GEMM array for multiple devices
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace_arr[TP_];
cutlass::device_memory::allocation<uint8_t> exclusive_workspace_arr[TP_];
// Cross-device workspace pointer array for gemm.initialize()
void * workspace_ptr_arr[TP_];
void * exclusive_workspace_ptr_arr[TP_];
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
DistGemmArguments arguments_[TP_];
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DistGemm::get_workspace_size(arguments_[device_idx]);
size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size();
workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(workspace_size);
exclusive_workspace_arr[device_idx] = cutlass::device_memory::allocation<uint8_t>(exclusive_workspace_size);
// Throw workspace pointers into arrays for gemm.initialize()
workspace_ptr_arr[device_idx] = workspace_arr[device_idx].get();
exclusive_workspace_ptr_arr[device_idx] = exclusive_workspace_arr[device_idx].get();
// Zero out exclusive workspace
CUDA_CHECK(cudaMemsetAsync(exclusive_workspace_ptr_arr[device_idx], 0, exclusive_workspace_size, stream_arr[device_idx]));
CUDA_CHECK(cudaDeviceSynchronize());
}
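// Note: the exclusive workspace is assumed to hold the cross-device synchronization state
// (see ElementBarrier/ElementFlag above), which is why it is zeroed on every device before
// the first launch; the regular workspace is the usual per-GEMM scratch space.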
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
// Check if the problem size is supported or not
CUTLASS_CHECK(dist_gemm_arr[device_idx].can_implement(arguments_[device_idx]));
#if defined(CUTLASS_ENABLE_GDC_FOR_SM100)
bool launch_with_pdl = true;
#else
bool launch_with_pdl = false;
#endif
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(dist_gemm_arr[device_idx].initialize(
arguments_,
workspace_ptr_arr,
exclusive_workspace_ptr_arr,
device_idx,
stream_arr[device_idx],
launch_with_pdl
));
CUDA_CHECK(cudaDeviceSynchronize());
}
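// Note: initialize() receives the full arguments_/workspace pointer arrays (not just this
// device's entries), presumably so each device's kernel can be given direct pointers into
// its peers' buffers for the cross-device copies and barriers.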
// Correctness / Warmup iteration
std::cout << std::endl << " running DistGEMM..." << std::endl;
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx]));
CUDA_CHECK(cudaGetLastError());
gather_results(options, device_idx);
}
std::cout << " running DistGEMM finished without runtime errors" << std::endl;
//// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << std::endl << " Disposition (eps: " << options.eps << "): " <<
(result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0) {
float elapsed_ms = 0.f;
// Warmup
std::cout << " Warming up for " << options.warmup_iterations << " iterations." << std::endl;
for (int warmup_iter = 0; warmup_iter < options.warmup_iterations; ++warmup_iter) {
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
}
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx]));
}
CUDA_CHECK(cudaSetDevice(primary_device_idx));
// Benchmark
std::cout << " Profiling for " << options.iterations << " iterations." << std::endl;
using AtomicBoolean = cuda::atomic<bool>;
AtomicBoolean* atomic_flag_ptr;
CUDA_CHECK(cudaHostAlloc(&atomic_flag_ptr, sizeof(AtomicBoolean), cudaHostAllocPortable));
atomic_flag_ptr->store(false);
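// Note: each device first enqueues a delay kernel that presumably spins on this host-pinned
// atomic flag; the per-device timers are started behind it, and only then is the flag
// released, so all TP streams begin the profiled iterations at approximately the same time.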
cutlass::DistGpuTimer<TP_> timer;
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
cutlass::delay_kernel<<<1, 1, 0, stream_arr[device_idx]>>>(atomic_flag_ptr);
CUDA_CHECK(cudaGetLastError());
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
timer.start(device_idx, stream_arr[device_idx]);
}
atomic_flag_ptr->store(true);
for (int profile_iter = 0; profile_iter < options.iterations; ++profile_iter) {
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx]));
}
}
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
CUDA_CHECK(cudaSetDevice(device_idx));
timer.stop(device_idx, stream_arr[device_idx]);
}
CUDA_CHECK(cudaSetDevice(primary_device_idx));
for (int device_idx = 0; device_idx < TP_; ++device_idx) {
elapsed_ms = max(elapsed_ms, timer.elapsed_millis(device_idx));
}
// Compute average runtime and TFLOPs.
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0);
result.tflops = options.tflops(avg_runtime_s);
auto [local_M, local_N, local_K, local_L] = DistSchedule::get_local_gemm_shape(
cute::make_tuple(options.m, options.n, options.k, options.l));
std::cout << std::endl;
std::cout << " TP: " << TP::value << std::endl;
std::cout << " Problem Size: " <<
options.m << " x " <<
options.n << " x " <<
options.k << " x " <<
options.l << std::endl;
std::cout << " Local GEMM Problem Size: " <<
local_M << " x " <<
local_N << " x " <<
local_K << " x " <<
local_L<< std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " TFLOPS: " << result.tflops << std::endl;
}
return 0;
}
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run Blackwell kernels.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
int num_devices;
CUDA_CHECK(cudaGetDeviceCount(&num_devices));
if (num_devices < TP_) {
std::cerr << "Distributed GEMM is compiled with TP = " << TP::value << ", but " <<
"found only " << num_devices << " devices." <<
std::endl;
// Returning zero so this test passes when fewer than TP devices are available. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
<< "(compute capability 100), "
<< "got compute capability " << props.major * 10 + props.minor << "."
<< std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
run(options);
#else
std::cerr
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.8 or later." << std::endl;
return 0;
#endif
return 0;
}
