Compare commits


39 Commits
v4.1.0 ... v4

Author SHA1 Message Date
57e3cfb47a doc change for 4.2 (#2639)
* doc change

* fix broken links

* ragged gemm doc update

* move around texts about moe gemm
2025-09-15 22:02:45 -04:00
e7e0adddac Update version.h
change version number to 4.2
2025-09-15 12:40:58 -04:00
6a35b4d22f v4.2 tag release. (#2638) 2025-09-15 12:21:53 -04:00
56f0718a97 ex77 backwards GQA (#2556)
* bwd GQA init

* Update examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu

* ref kernel type conversion fix

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-09-09 12:53:28 -04:00
76c96b0be3 Fix incorrect shapes in copy_atom doc comments. (#2575) 2025-09-04 16:57:24 -07:00
d98e7bf7ce Fix comment in mma_atom.hpp (#2579) 2025-09-04 16:56:39 -07:00
b6ccf34aef Fix Copy_Atom type mismatch in sgemm_sm80.cu (#2582) 2025-09-04 16:56:17 -07:00
2288c0c901 Fix bugs in matrix.h (#2598) 2025-09-04 16:55:11 -07:00
b2dd65dc86 more robust imports in heuristics.py and heuristics_provider.py (#2596) 2025-08-28 22:32:55 -04:00
496654bf2c Fix sm100 gemm wrong static constexpr that breaks compilation on Windows (#2167)
* Fix a sm100 gemm wrong defined static constexpr that breaks compilation on Windows

* Fix a sm100 gemm wrong defined static constexpr that breaks compilation on Windows

* More Windows fixes

Signed-off-by: Javier <25750030+SystemPanic@users.noreply.github.com>

* Revert "More Windows fixes"

This reverts commit 2e8cfc1382.

---------

Signed-off-by: Javier <25750030+SystemPanic@users.noreply.github.com>
2025-08-28 22:13:00 -04:00
9ca7e877b2 fix gqa issue for blackwell fmha.py (#2599) 2025-08-28 11:15:20 -04:00
a49a78ffef v4.2 release. (#2587)
* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in command line.

* v4.2 release.
2025-08-22 18:11:24 -04:00
11cad1f67b fix a typo. (#2561) 2025-08-19 22:23:09 -04:00
931359cec1 Fix typo in functional.h (#2571) 2025-08-19 22:22:31 -04:00
42e7c546c4 Add movmatrix support (movmatrix.sync.aligned.m8n8.trans.b16) (#2562) 2025-08-19 22:22:02 -04:00
ec18e8043b Make swizzle in pycute work (#2553) 2025-08-19 22:21:00 -04:00
5b76420d6a [DOC] Add more exposition to composition example (#2536)
* Add more exposition to composition example

* Apply suggestions from code review

Co-authored-by: Cris Cecka <ccecka@users.noreply.github.com>

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Cris Cecka <ccecka@users.noreply.github.com>
2025-08-11 22:20:36 -04:00
19772cd63e Fix typo in smem_allocator.py (#2517) 2025-08-10 22:44:22 -04:00
052afcd314 fix typo (#2529) 2025-08-10 22:44:02 -04:00
86cf63e2d4 NIT: Grammar (#2537) 2025-08-10 22:42:45 -04:00
a267d47f9b Update batched_gemm.cu (#2538) 2025-08-10 22:42:21 -04:00
9e6ab77d27 Fix a copy error in the SM70 main loop when loading data from smem to rmem (#2540) 2025-08-10 22:42:01 -04:00
d0eada85a3 Support both CUDA 12 and 13 cccl header locations (#2543) 2025-08-10 22:41:25 -04:00
23139309e9 Fix incorrect K dim in CuTe MMA Atom doc. (#2544) 2025-08-10 22:40:56 -04:00
6dd13d4278 Facebook: This commit makes its files safe for use with -Wimplicit-fallthrough. (#2324) 2025-07-31 20:55:19 -04:00
3b054767b3 Fix typo (#2514) 2025-07-30 22:14:54 -04:00
6fb5e667c1 [Doc fix] incorrect compute cap. for Blackwell RTX (#2511)
Blackwell RTX is compute capability 12.0 (SM120) but incorrectly listed
as SM100 in the README.
2025-07-30 22:14:13 -04:00
6c891db9f6 Fix epilogue:🧵:Convert cannot be used with cute::collective::DefaultEpilogue. (#2333) 2025-07-30 22:12:53 -04:00
da47886e34 Fix example bug (#2351) 2025-07-30 22:12:33 -04:00
26b7450023 support fp16 accumulator for sm89 fp8 mma (#2378)
* add support for sm89 in cute and the unit tests

* support fp16 accumulator for sm89 fp8 mma

* format code
2025-07-30 22:12:08 -04:00
a39cf6b511 Fix example in CuTe tutorials (#2416) 2025-07-30 22:11:47 -04:00
f09045d660 Corrected minor nit in mma_traits.hpp (#2447)
* Corrected minor nit in mma_traits.hpp

The entry and descriptions were jumbled up.

* Update mma_traits.hpp

* Update mma_traits.hpp
2025-07-30 22:11:23 -04:00
84a27b3926 fix: examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu GridDim miscalculated (#2492)
* fix: examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu Launch dimGrid error

* feat: add cta tiler

* Update examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu

use cluster_layout_vmnk instead of cta_tiler

Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>

* feat: remove cta_tiler

---------

Co-authored-by: qinghongzeng <qinghongzeng@deeproute.ai>
Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>
2025-07-30 22:11:04 -04:00
e093b4f691 Fix tutorial comment in sgemm_1.cu: use tCrC instead of tCsA in axpby explanation (#2448) 2025-07-30 22:09:55 -04:00
664c4f7b3e Update CUTLASS version to 4.1
Update CUTLASS version to 4.1.
2025-07-26 20:11:04 -04:00
0e026982ce Example 77 add blackwell fmha bwd for MLA shape (#2466)
* Update examples/77_blackwell_fmha/device/fmha_device_bwd.hpp

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

* bug fix & use existing value rather than pass one more argument to support different dim in bwd_convert

* Fix causal mask cnt when IsQBegin==false

* bug fix in causal mask backward

* code sync

---------

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
2025-07-24 18:41:11 -04:00
9a9a579714 Merge pull request #2489 from NVIDIA/update_workflow_script
Support "CuTe DSL" auto-labeling in workflow
2025-07-23 15:33:43 +08:00
51d730b8be Support "CuTe DSL" auto-labeling in workflow 2025-07-23 00:28:01 -07:00
6c0c8b7484 1. Update bug/feature report template to add component selection. (#2485)
2. Add workflow to apply component label automatically
2025-07-22 12:38:03 -04:00
483 changed files with 43726 additions and 6249 deletions

View File

@ -1,23 +0,0 @@
---
name: Bug report
about: Create a bug report to help us improve CUTLASS
title: "[BUG]"
labels: "? - Needs Triage, bug"
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Steps/Code to reproduce bug**
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Environment details (please complete the following information):**
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]
**Additional context**
Add any other context about the problem here.

38
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: Bug Report
description: Create a bug report to help us improve CUTLASS
title: "[BUG] "
labels: ["? - Needs Triage", "bug"]
assignees: []
body:
  - type: dropdown
    id: component
    attributes:
      label: Which component has the problem?
      options:
        - CuTe DSL
        - CUTLASS C++
    validations:
      required: true
  - type: textarea
    id: bug-report
    attributes:
      label: Bug Report
      description: Please fill out all sections below
      value: |
        **Describe the bug**
        A clear and concise description of what the bug is.
        **Steps/Code to reproduce bug**
        Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.
        **Expected behavior**
        A clear and concise description of what you expected to happen.
        **Environment details (please complete the following information):**
        - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]
        **Additional context**
        Add any other context about the problem here.
    validations:
      required: true

View File

@ -1,20 +0,0 @@
---
name: Feature request
about: Suggest an idea for CUTLASS
title: "[FEA]"
labels: "? - Needs Triage, feature request"
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context, code examples, or references to existing implementations about the feature request here.

View File

@ -0,0 +1,35 @@
name: Feature Request
description: Suggest an idea for CUTLASS
title: "[FEA] "
labels: ["? - Needs Triage", "feature request"]
assignees: []
body:
  - type: dropdown
    id: component
    attributes:
      label: Which component requires the feature?
      options:
        - CuTe DSL
        - CUTLASS C++
    validations:
      required: true
  - type: textarea
    id: feature-request
    attributes:
      label: Feature Request
      description: Please fill out all sections below
      value: |
        **Is your feature request related to a problem? Please describe.**
        A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]
        **Describe the solution you'd like**
        A clear and concise description of what you want to happen.
        **Describe alternatives you've considered**
        A clear and concise description of any alternative solutions or features you've considered.
        **Additional context**
        Add any other context, code examples, or references to existing implementations about the feature request here.
    validations:
      required: true

51
.github/workflows/auto-label-issues.yml vendored Normal file
View File

@ -0,0 +1,51 @@
name: Auto Label Issues
on:
  issues:
    types: [opened]
jobs:
  add-labels:
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Add component label
        uses: actions/github-script@v7
        with:
          script: |
            const issue = context.payload.issue;
            const body = issue.body || '';
            // Parse the issue body to find the component selection
            // GitHub renders dropdown selections as "### {label}\n\n{selection}"
            // Check for both bug report and feature request dropdown labels
            const bugComponentMatch = body.match(/### Which component has the problem\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);
            const featureComponentMatch = body.match(/### Which component requires the feature\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);
            const componentMatch = bugComponentMatch || featureComponentMatch;
            if (componentMatch) {
              const component = componentMatch[1].trim();
              let label = '';
              // Map component selections to labels
              switch(component) {
                case 'CuTe DSL':
                  label = 'CuTe DSL';
                  break;
                case 'CUTLASS C++':
                  label = 'CUTLASS C++';
                  break;
              }
              if (label) {
                await github.rest.issues.addLabels({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: issue.number,
                  labels: [label]
                });
                console.log(`Added label: ${label}`);
              }
            }
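For illustration only, the component extraction performed by the github-script step above can be sketched in C++ with `std::regex`. The rendered-body format (`### {label}`, a blank line, then the selection) is taken from the script's own comments; `extract_component` is a hypothetical helper, not part of any CUTLASS tooling.

```cpp
// Hypothetical sketch: extract the dropdown selection from a rendered issue body,
// mirroring the regex used by the workflow above. Not part of CUTLASS.
#include <iostream>
#include <optional>
#include <regex>
#include <string>

std::optional<std::string> extract_component(const std::string& body) {
  // GitHub renders a dropdown as "### {label}\n\n{selection}".
  static const std::regex re(
      R"(### Which component (?:has the problem|requires the feature)\?\s*\n\s*\n\s*(.+))");
  std::smatch m;
  if (std::regex_search(body, m, re)) {
    return m[1].str();  // the selected component, e.g. "CuTe DSL"
  }
  return std::nullopt;
}

int main() {
  std::string body = "### Which component has the problem?\n\nCuTe DSL\n\n### Bug Report\n...";
  if (auto c = extract_component(body)) {
    std::cout << "component: " << *c << "\n";  // prints "component: CuTe DSL"
  }
  return 0;
}
```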

View File

@ -55,7 +55,7 @@ jobs:
if: |
(startsWith(github.event.comment.body, '/bot run') ||
startsWith(github.event.comment.body, '/bot kill')) && contains(
fromJson('["zekunf-nv"]'),
fromJson('["nv-fastkernels-cicd", "zekunf-nv", "hwu36", "IonThruster", "thakkarV", "d-k-b", "mihir-awatramani", "fengxie", "vickiw973", "Junkai-Wu", "brandon-yujie-sun", "lijingticy22", "hongw-nv", "vikgupta-nv", "IwakuraRein", "depaulmillz", "jackkosaian", "itramble", "ccecka", "sxtyzhangzk", "hbarclay", "yzhaiustc", "x86vk", "sklevtsov-nvidia", "ANIKET-SHIVAM", "Shreya-gaur", "azhurkevich", "serifyesil", "richardmcai", "lsyyy666", "Ethan-Yan27", "XiaoSong9905", "shdetect", "keithzzzzz"]'),
github.actor)
steps:
- name: Check if comment is issued by authorized person

View File

@ -2,6 +2,97 @@
# CUTLASS 4.x
## [4.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0) (2025-09-15)
### CuTe DSL
* More Python versions are now supported for both x86-64 and aarch64, including
- Python 3.10, 3.11, 3.12, and 3.13
* Added new example and updated notebook to get started with CuTe DSL
- [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py)
- Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb)
+ Added a section introducing broadcasting
* API updates
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixes and improvements
- Fixed `cute.print_tensor` for coordinate tensors
- Fixed `cute.print` for tuples of layouts
- Fixed an issue where a frozen object was not properly updated after being fully assigned in dynamic control flow
- Fixed an issue where assigning a tuple/list element in dynamic control flow could cause a compilation failure
- Improved the error message when the CUDA context is not initialized
- Improved the docstrings of `congruent` and `weakly_congruent`
### CUTLASS C++
* Support for Blackwell SM103 kernels for B300 GPUs.
- Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp)
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
- Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture:
- [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/).
- [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm).
* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM
- Unit test files prefixed with `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/).
* Support for Blackwell SM121 kernels for DGX Spark GPUs.
- These share most of their code with the Blackwell SM120 kernels.
* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario.
- For details, please refer to the [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md).
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
- Add fused reduction kernel support for cutlass MLA.
- Add softmax skip correction.
- Support for GQA in FMHA backward kernel.
- Fix an issue where `get_unmasked_trip_count` may return a negative value.
- Fix an issue where mbarriers are initialized with a zero arrival count.
- Fix a corner case issue where the sequence length of q is not a multiple of tile_q.
- Remove tma padding for forward kernel inputs.
* Add Blackwell SM100 kernels for MoEs (focusing on Low-Latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allows only one problem dimension to vary across groups/experts, unlike general Grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on the API is welcome.
* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell
- On Blackwell SM120, a blockwise gemm kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/).
- On Hopper, add K major scale factor support for SM90 blockwise kernels.
- On Hopper, relax the restriction that the k dimension of the problem size has to be a multiple of the k dimension of the tile size.
- On Hopper, grouped version supports the case when k = 0.
* Support for Blackwell SM100 fp4 gemv kernels.
- Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h).
- Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/)
* Support for Blackwell SM100 legacy mixed input GEMM kernels.
- Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp).
- Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp).
- Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/).
* Support for Blackwell SM100 cpasync kernel.
- Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp).
- Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp).
* Support Blackwell SM120 mixed input blockscaled grouped GEMM.
* Instantiate more Blackwell kernels in the profiler.
- Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations.
- To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. The profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels.
- For details, please check the [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md).
* Fix some profiler issues:
- Modify default cluster callback values to be non-zero to avoid profiler failure when these values are not set on the command line.
- Fix some no output and timeout issues.
- Fix Pingpong Blockwise Hopper library generation.
* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110.
- For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs.
- For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid.
* Rename legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to legacy Python interface.
- Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`.
- Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter.
- Added some support for running SM100 kernels via the Python interface.
* CuTe changes:
- Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/).
- Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support.
- Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels.
- Support fp16 accumulator for sm89 fp8 mma.
- Shorten `nullspace` implementation.
- Isolate and comment on `cosize` hacks.
- Important documentation correction: `E<0,1> == 1@0@1`.
* Fix some kernel issues:
- Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers.
- Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel.
* Add the following unit tests:
- [fp16 accumulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu)
- [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu)
- [fp16 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu)
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit version 13.0U1.
## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)
### CuTe DSL
@ -10,7 +101,7 @@
- [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
- [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
* API updates
- Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
### CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
@ -58,7 +149,7 @@
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
- Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
### CUTLASS C++
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9

View File

@ -175,13 +175,25 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a 121 121a)
if (CUDA_VERSION VERSION_LESS 13.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
else()
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
endif()
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f 121f 103a 103f)
if (CUDA_VERSION VERSION_LESS 13.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
else()
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110f)
endif()
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
@ -288,17 +300,49 @@ if (KERNEL_FILTER_FILE)
set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE)
endif()
if (CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
get_filename_component(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" ABSOLUTE)
set(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" CACHE STRING "HEURISTICS FILE FULL PATH" FORCE)
endif()
set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list")
if(KERNEL_FILTER_FILE)
message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}")
endif()
if(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
message(STATUS "Full path of heuristics problems file: ${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}")
if(DEFINED CUTLASS_NVMMH_URL)
message(STATUS "CUTLASS_NVVMH_URL is set. Fetching dependency")
include(FetchContent)
FetchContent_Declare(
nvmmh
URL ${CUTLASS_NVMMH_URL}
)
FetchContent_MakeAvailable(nvmmh)
FetchContent_GetProperties(nvmmh SOURCE_DIR nvmmh_dir)
set(CUTLASS_NVMMH_PATH "${nvmmh_dir}")
endif()
if(DEFINED CUTLASS_NVMMH_PATH)
message(STATUS "CUTLASS_NVMMH_PATH is set. Using package at: ${CUTLASS_NVMMH_PATH}")
set(CUTLASS_NVMMH_PY_DIR "${CUTLASS_NVMMH_PATH}/python/")
set(ENV{CUTLASS_NVMMH_SO_PATH} "${CUTLASS_NVMMH_PATH}/lib/libnvMatmulHeuristics.so")
endif()
endif()
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.")
set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 and SM100 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")
if(CUTLASS_LIBRARY_INSTANTIATION_LEVEL OR CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
message(STATUS "Enable extended SM90 WGMMA instruction shapes for instantiation levels")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
endif()
################################################################################
@ -350,6 +394,10 @@ if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()
if (CUTLASS_NVCC_ARCHS MATCHES 110f)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()
set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")
#
@ -428,8 +476,6 @@ endif()
#
###################################################################################################
# Warnings-as-error exceptions and warning suppressions for Clang builds
if (CUTLASS_CLANG_HOST_COMPILE)
@ -704,9 +750,16 @@ target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)
if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)
endif()
install(
DIRECTORY
${CUTLASS_INCLUDE_DIR}/
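As a hedged illustration of how a compile definition such as `-DCUTLASS_SM100_FAMILY_ARCHS_ENABLED` (appended by the CMake logic above for the 100f/101f/110f targets) is typically consumed in source, the sketch below guards a code path with `#if defined(...)`. The surrounding function is hypothetical and not taken from CUTLASS; only the macro name comes from the CMake lines shown.

```cpp
// Minimal sketch (not CUTLASS source): a translation unit that compiles a
// family-specific code path only when CMake passed
// -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED to the compiler.
#include <cstdio>

void report_build_config() {
#if defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)
  std::printf("Built with SM100 family-specific architecture features enabled.\n");
#else
  std::printf("SM100 family-specific architecture features not enabled.\n");
#endif
}

int main() {
  report_build_config();
  return 0;
}
```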

View File

@ -1,30 +0,0 @@
# Changelog for CuTe DSL API changes
## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)
* for loop
- Python built-in ``range`` now always generates IR and executes at runtime
- ``cutlass.range`` is advanced ``range`` with IR level unrolling and pipelining control
- Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range``
- **Experimental** Added ``pipelining`` control for compiler generated software pipeline code
* while/if
- ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate
- Deprecated ``cutlass.dynamic_expr``, please remove it
* Rename mbarrier functions to reduce ambiguity
* Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier`
* Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional.
* Introduce `cutlass.cute.arch.get_dyn_smem_size` api to get runtime dynamic shared memory size.
* Various API Support for SM100 BlockScaled Gemm
- Introduce BlockScaled MmaOps in [tcgen05/mma.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py), and provide a `make_blockscaled_trivial_tiled_mma` function in [blackwell_helpers.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blackwell_helpers.py) to help construct a BlockScaled TiledMma.
- Introduce S2T CopyOps in [tcgen05/copy.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py).
- Introduce BlockScaled layout utilities in [blockscaled_layout.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blockscaled_layout.py) for creating the required scale factor layouts in global memory, shared memory and tensor memory.
* `cutlass.cute.compile` now supports compilation options. Refer to [JIT compilation options](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.html) for more details.
* `cutlass.cute.testing.assert_` now works for device JIT function. Specify `--enable-device-assertions` as compilation option to enable.
* `cutlass.cute.make_tiled_copy` is now deprecated. Please use `cutlass.cute.make_tiled_copy_tv` instead.
* Shared memory capacity query
- Introduce `cutlass.utils.get_smem_capacity_in_bytes` for querying the shared memory capacity.
- `<arch>_utils.SMEM_CAPACITY["<arch_str>"]` is now deprecated.
## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)
* Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``

120
README.md
View File

@ -1,9 +1,9 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# Overview
# CUTLASS 4.1.0
# CUTLASS 4.2.0
_CUTLASS 4.1.0 - July 2025_
_CUTLASS 4.2.0 - Sept 2025_
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@ -27,14 +27,14 @@ native support of such data types) across NVIDIA's Volta, Turing, Ampere, Ada, H
To this rich ecosystem of C++ based kernel programming abstractions, CUTLASS 4 adds CUTLASS DSLs. These are Python native interfaces for writing high-performance CUDA kernels based on core CUTLASS and CuTe concepts without any performance compromises. This allows for a much smoother learning curve, orders of magnitude faster compile times, native integration with DL frameworks without writing glue code, and much more intuitive metaprogramming that does not require deep C++ expertise.
Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy.
Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions -- exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy.
CuTe DSL demonstrates optimal matrix multiply and other linear algebra operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Ampere, Hopper, and Blackwell architectures.
We believe it will become an indispensable tool for students, researchers, and performance
engineers alike flattening the learning curve of GPU programming, rapidly prototyping kernel
engineers alike -- flattening the learning curve of GPU programming, rapidly prototyping kernel
designs, and bringing optimized solutions into production.
CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
@ -43,40 +43,94 @@ To get started quickly - please refer :
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).
# What's New in CUTLASS 4.1
# What's New in CUTLASS 4.2
## CuTe DSL
* Add aarch64 support, you can now pip install `nvidia-cutlass-dsl` on GB200 systems!
* More examples demonstrating how to use CuTe DSL to write peak-performance kernels
- [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
- [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
* More Python versions are now supported for both x86-64 and aarch64, including
- Python 3.10, 3.11, 3.12, and 3.13
* Added new example and updated notebook to get started with CuTe DSL
- [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py)
- Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb)
+ Added a section introducing broadcasting
* API updates
- Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixes and improvements
- Fixed `cute.print_tensor` for coordinate tensors
- Fixed `cute.print` for tuples of layouts
- Fixed an issue where a frozen object was not properly updated after being fully assigned in dynamic control flow
- Fixed an issue where assigning a tuple/list element in dynamic control flow could cause a compilation failure
- Improved the error message when the CUDA context is not initialized
- Improved the docstrings of `congruent` and `weakly_congruent`
## CUTLASS C++
* Support for Blackwell SM103 kernels for B300 GPUs.
- Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp)
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
- Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture:
- [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/).
- [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm).
* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM
- Unit test files prefixed with `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/).
* Support for Blackwell SM121 kernels for DGX Spark GPUs.
- These share most of their code with the Blackwell SM120 kernels.
* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario.
- For details, please refer to the [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md).
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
- Add variable sequence length support for FMHA Backward kernel.
- Add varlen test support to Backward runner.
- Codes support empty batch sequences.
* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays.
- Add fused reduction kernel support for cutlass MLA.
- Add softmax skip correction.
- Support for GQA in FMHA backward kernel.
- Fix an issue where `get_unmasked_trip_count` may return a negative value.
- Fix an issue where mbarriers are initialized with a zero arrival count.
- Fix a corner case issue where the sequence length of q is not a multiple of tile_q.
- Remove tma padding for forward kernel inputs.
* Add Blackwell SM100 kernels for MoEs (focusing on Low-Latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allows only one problem dimension to vary across groups/experts, unlike general Grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on the API is welcome.
* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell
- On Blackwell SM120, a blockwise gemm kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/).
- On Hopper, add K major scale factor support for SM90 blockwise kernels.
- On Hopper, relax the restriction that the k dimension of the problem size has to be a multiple of the k dimension of the tile size.
- On Hopper, grouped version supports the case when k = 0.
* Support for Blackwell SM100 fp4 gemv kernels.
- Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h).
- Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/)
* Support for Blackwell SM100 legacy mixed input GEMM kernels.
- Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp).
- Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp).
- Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/).
* Support for Blackwell SM100 cpasync kernel.
- Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp).
- Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp).
* Support Blackwell SM120 mixed input blockscaled grouped GEMM.
* Instantiate more Blackwell kernels in the profiler.
- Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations.
- To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. The profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels.
- For details, please check the [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md).
* Fix some profiler issues:
- Modify default cluster callback values to be non-zero to avoid profiler failure when these values are not set on the command line.
- Fix some no output and timeout issues.
- Fix Pingpong Blockwise Hopper library generation.
* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110.
- For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs.
- For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid.
* Rename legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to legacy Python interface.
- Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`.
- Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter.
- Added some support for running SM100 kernels via the Python interface.
* CuTe changes:
- Rewrite ArithTuple and ScaledBasis for robustness and clarity.
- Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX.
- Factor out `print_latex` and friends and rewrite.
- Factor out `print_svg` and friends and rewrite.
* Support Blackwell SM100 SIMT packed fp32x2 kernels.
* Support residual add for implicit gemm kernels.
* Various fixes for CUTLASS C++ Python interface's EVT tracer:
- Add verifier for sm90 to report the invalid input.
- When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges.
- Register operations of tanh, sigmoid, exp, gelu to the python ast frontend.
- Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback.
* Fix profiler bugs in exhaustive perf search.
- Fix incorrect cluster shape output issue when doing exhaustive search.
- Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders.
* Fix some profiler issues.
- Complete the reference for Blackwell blockwise gemm kernels.
- Fix incorrect regex logic for L1 test.
- Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/).
- Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support.
- Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels.
- Support fp16 accumulator for sm89 fp8 mma.
- Shorten `nullspace` implementation.
- Isolate and comment on `cosize` hacks.
- Important documentation correction: `E<0,1> == 1@0@1`.
* Fix some kernel issues:
- Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers.
- Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel.
* Add the following unit tests:
- [fp16 accumulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu)
- [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu)
- [fp16 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu)
Note: CUTLASS 4.x builds are known to fail on Windows platforms for all CUDA toolkits.
The CUTLASS team is working on a fix.
@ -170,7 +224,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|NVIDIA GeForce RTX 50x0 series |10.0|12.8|
|NVIDIA GeForce RTX 50x0 series |12.0|12.8|
## Target Architecture
@ -202,7 +256,7 @@ cmake .. -DCUTLASS_NVCC_ARCHS="100a"
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
products has a different compute capability than the one underpinning
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
NVIDIA Blackwell GeForce RTX 50 series GPUs (SM120). As a result, kernels
compiled for Blackwell SM100 architecture with arch conditional features
(using `sm100a`) are not compatible with RTX 50 series GPUs.

View File

@ -65,7 +65,12 @@ endfunction()
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f)
set(PROFILER_ARCH_LIST 100a 100f 103a 120a 120f 121a)
if (CUDA_VERSION VERSION_LESS 13.0)
list(APPEND PROFILER_ARCH_LIST 101a 101f)
else()
list(APPEND PROFILER_ARCH_LIST 110a 110f)
endif()
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests")

View File

@ -45,7 +45,7 @@
cutlass::half_t
This is a numeric type implementing IEEE half-precision quantities. It is functional in host
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables harware-accelerated
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables hardware-accelerated
numeric conversion on x86-64 CPUs support F16C extensions. In device code, all available
hardware is used to implement conversion and numeric operations.
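A minimal host-side sketch of the numeric type described above, assuming the CUTLASS include directory is on the compiler's include path. When `CUTLASS_ENABLE_F16C` is defined and the CPU supports F16C, the float conversions can take the hardware-accelerated path; otherwise a software conversion is used.

```cpp
// Minimal host-side sketch using cutlass::half_t (IEEE FP16). Assumes the
// CUTLASS headers are available on the include path.
#include <cstdio>
#include "cutlass/numeric_types.h"

int main() {
  cutlass::half_t a = cutlass::half_t(1.5f);   // float -> half conversion
  cutlass::half_t b = cutlass::half_t(0.25f);
  cutlass::half_t c = a * b;                   // arithmetic is defined for half_t
  float result = float(c);                     // half -> float conversion
  std::printf("1.5 * 0.25 = %f (computed via cutlass::half_t)\n", result);
  return 0;
}
```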

View File

@ -243,10 +243,11 @@ cudaError_t run_batched_gemm(bool use_array) {
const char* gemm_desc = use_array ? "array" : "strided batched";
std::cout << "Running " << gemm_desc << " gemm" << std::endl;
// Arbitrary problem size
// Arbitrary matrix shape
int const m = 520;
int const n = 219;
int const k = 129;
int const batch_count = 17;
// A, B are non-transpose, column major

View File

@ -659,7 +659,7 @@ struct Testbed {
}
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
int64_t bytes = cutlass::bits_to_bytes(
int64_t bytes = cutlass::bits_to_bytes<int64_t>(
(cutlass::sizeof_bits<ElementD>::value * 2 + cutlass::sizeof_bits<ElementSoftmax>::value) *
options.problem_size.m() * options.problem_size.n());
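To see why the diff above adds an explicit `<int64_t>` to the byte computation, here is a small standalone sketch using a local helper (not the CUTLASS function) and assumed element widths (16-bit ElementD, 32-bit ElementSoftmax): the bits product easily exceeds `INT32_MAX` for large problem sizes, so the intermediate must be accumulated in 64 bits.

```cpp
// Standalone sketch of the overflow that the explicit 64-bit instantiation avoids.
// bits_to_bytes_demo is a local stand-in, not the CUTLASS helper.
#include <cstdint>
#include <cstdio>

template <typename T>
T bits_to_bytes_demo(T bits) { return (bits + 7) / 8; }

int main() {
  // Per-element payload as in the testbed: two ElementD values plus one
  // ElementSoftmax value (assumed here to be 16-bit and 32-bit respectively).
  int64_t m = 65536, n = 65536;
  int64_t bits_per_element = 16 * 2 + 32;          // 64 bits per (m, n) entry

  int64_t total_bits = bits_per_element * m * n;   // 2^38, safely held in 64 bits
  int64_t bytes = bits_to_bytes_demo<int64_t>(total_bits);

  std::printf("total bits  = %lld (INT32_MAX = %lld)\n",
              (long long)total_bits, (long long)INT32_MAX);
  std::printf("total bytes = %lld\n", (long long)bytes);
  // Had the product been formed in 32-bit arithmetic it would have overflowed,
  // which is why the testbed now instantiates the helper with int64_t.
  return 0;
}
```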

View File

@ -33,8 +33,8 @@
computing reference permutations of 4/5D tensors when source data is column-major.
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/coord.h"

View File

@ -40,14 +40,12 @@
Note that in general the fragment passed to the OutputOp could
span multiple rows but it does not happen with the configurations we have
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"

View File

@ -42,12 +42,10 @@
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"

View File

@ -38,10 +38,8 @@
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"

View File

@ -37,12 +37,10 @@
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/array.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"

View File

@ -26,7 +26,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 90a)
cutlass_example_add_executable(
65_distributed_gemm
65_distributed_gemm.cu
)
endif()

View File

@ -129,7 +129,7 @@ using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_confi
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

View File

@ -132,12 +132,12 @@ constexpr int ScaleGranularityK = 128;
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<

View File

@ -145,7 +145,7 @@ using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularity
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
@ -402,12 +402,37 @@ void initialize(const OptionType &options) {
beta_host.clear();
for (int i = 0; i < options.groups; i++) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
// If the current group's matrix has size 0, set the pointer to nullptr
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
ptr_A_host.at(i) = nullptr;
} else {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
}
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
ptr_B_host.at(i) = nullptr;
} else {
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
}
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
ptr_C_host.at(i) = nullptr;
} else {
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
}
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
ptr_D_host.at(i) = nullptr;
} else {
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
ptr_blockscale_A_host.at(i) = nullptr;
} else {
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
}
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
ptr_blockscale_B_host.at(i) = nullptr;
} else {
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
}
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
@ -546,10 +571,10 @@ bool verify(const OptionType &options) {
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
bool passed = true;
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
auto gemm_problem_shape = cute::make_shape(m, n, k);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
@ -598,11 +623,7 @@ bool verify(const OptionType &options) {
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
unused_t, // Aux
unused_t, // valpha
unused_t // vbeta
decltype(D)
> epilogue_params;
epilogue_params.C = C;
@ -639,6 +660,24 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
allocate(options);
initialize(options);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
@ -671,8 +710,7 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
}
// Run profiling loop
if (options.iterations > 0)
{
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
@ -686,25 +724,6 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
fflush(stdout);
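The `initialize()` changes above detect an empty group by comparing consecutive offsets and substituting `nullptr` for the per-group pointer. A small self-contained sketch of that check, with made-up offsets and outside the example's classes, follows; note that, as in the original loop, the last group can never be flagged empty because there is no following offset to compare against.

```cpp
// Self-contained sketch of the empty-group detection used in initialize() above:
// when a group's matrix has zero elements, its offset equals the next group's
// offset, and the per-group pointer is set to nullptr instead of a valid address.
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical element offsets for 4 groups; group 1 is empty (300 == 300).
  std::vector<long> offset_A = {0, 300, 300, 800};
  std::vector<float> block_A(1000, 1.0f);                    // stand-in for the device allocation
  std::vector<float*> ptr_A_host(offset_A.size(), nullptr);  // per-group base pointers

  int groups = static_cast<int>(offset_A.size());
  for (int i = 0; i < groups; ++i) {
    // If the current group's matrix has size 0, set the pointer to nullptr.
    if (i < groups - 1 && offset_A[i] == offset_A[i + 1]) {
      ptr_A_host[i] = nullptr;
    } else {
      ptr_A_host[i] = block_A.data() + offset_A[i];
    }
  }

  for (int i = 0; i < groups; ++i) {
    std::printf("group %d: %s\n", i, ptr_A_host[i] ? "valid pointer" : "nullptr (empty)");
  }
  return 0;
}
```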

View File

@ -132,8 +132,7 @@ using ElementCompute = float; // E
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
static constexpr int ScaleGranularityM = 1;
@ -142,13 +141,13 @@ static constexpr int ScaleGranularityK = 128;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
@ -407,12 +406,37 @@ void initialize(const OptionType &options) {
beta_host.clear();
for (int i = 0; i < options.groups; i++) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
// If the current group's matrix has size 0, set the pointer to nullptr
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
ptr_A_host.at(i) = nullptr;
} else {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
}
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
ptr_B_host.at(i) = nullptr;
} else {
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
}
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
ptr_C_host.at(i) = nullptr;
} else {
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
}
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
ptr_D_host.at(i) = nullptr;
} else {
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
ptr_blockscale_A_host.at(i) = nullptr;
} else {
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
}
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
ptr_blockscale_B_host.at(i) = nullptr;
} else {
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
}
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
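The six if/else blocks above repeat a single pattern: when a group's offset equals the next group's offset, the group owns no storage and its pointer is set to nullptr. A possible refactor is sketched below; group_ptr_or_null is a hypothetical helper name, not part of the example.
// Returns base + offsets.at(i), or nullptr when group i is empty, i.e. when its offset
// equals the next group's offset. As in the expanded code above, the last group is
// never treated as empty by this test.
template <class T, class OffsetVector>
T* group_ptr_or_null(T* base, OffsetVector const& offsets, int i, int groups) {
  bool empty = (i < groups - 1) && (offsets.at(i) == offsets.at(i + 1));
  return empty ? nullptr : base + offsets.at(i);
}
// Usage (illustrative): ptr_A_host.at(i) = group_ptr_or_null(block_A.get(), offset_A, i, options.groups);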
@ -551,10 +575,10 @@ bool verify(const OptionType &options) {
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
bool passed = true;
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx);
auto gemm_problem_shape = cute::make_shape(m, n, k);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
@ -637,10 +661,27 @@ bool verify(const OptionType &options) {
template <typename OptionType>
int run(OptionType &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
@ -695,27 +736,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
ScaleMsPerTile,
ScaleNsPerTile>(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
std::cout << " GBPS: " << result.gbps << std::endl;
fflush(stdout);
}
return 0;
@ -766,8 +790,8 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
std::cout << "Running tests with host problem shapes:" << std::endl;
run(options, true);
std::cout << "Running tests without host problem shapes:" << std::endl;
run(options, false);

View File

@ -44,6 +44,9 @@ set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0)
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
set(TEST_K_16B_ALIGNED --m=256 --n=512 --k=960 --groups=10 --iterations=0)
set(TEST_K_16B_ALIGNED_LARGE_GROUP --m=256 --n=512 --k=960 --groups=512 --iterations=0)
cutlass_example_add_executable(
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
@ -58,6 +61,8 @@ cutlass_example_add_executable(
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_K_16B_ALIGNED
TEST_K_16B_ALIGNED_LARGE_GROUP
)
# MSVC will fail to compile this example with the following error:

View File

@ -111,14 +111,14 @@ struct Options {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
if (m < 0) {
m = m_alignment * (rand() % (64 * alignment / m_alignment));
}
if (n < 1) {
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
if (n < 0) {
n = n_alignment * (rand() % (64 * alignment / n_alignment));
}
if (k < 1) {
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
if (k < 0) {
k = k_alignment * (rand() % (32 * alignment / k_alignment));
}
problem_sizes_after_alignment_host.push_back({m, n, k});
problem_sizes_host.push_back({m, n, k});

View File

@ -454,11 +454,12 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -640,11 +640,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}
}
//
// Parse options

View File

@ -33,7 +33,7 @@ set(TEST_SWIZZLE_2 --swizzle=2)
set(TEST_SWIZZLE_5 --swizzle=5)
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
cutlass_example_add_executable(
70_blackwell_fp16_gemm
70_blackwell_fp16_gemm.cu

View File

@ -449,9 +449,9 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}

View File

@ -27,7 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
cutlass_example_add_executable(
71_blackwell_gemm_with_collective_builder
71_blackwell_gemm_with_collective_builder.cu

View File

@ -116,7 +116,7 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
// Kernel Perf config
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
@ -511,10 +511,10 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}
}
//
// Parse options
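The negated check above is equivalent, by De Morgan's law, to requiring compute capability 10.0, 10.1, or 10.3. The same gate written positively, as a sketch with an illustrative helper name:
#include <cuda_runtime.h>
// True for the SM100 / SM101 / SM103 (Blackwell) families this example targets.
inline bool is_supported_blackwell(cudaDeviceProp const& props) {
  return props.major == 10 &&
         (props.minor == 0 || props.minor == 1 || props.minor == 3);
}
// Usage (illustrative): if (!is_supported_blackwell(props)) { /* print the message and return 0 */ }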

View File

@ -566,8 +566,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}

View File

@ -117,7 +117,7 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
// Kernel Perf config
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
@ -512,8 +512,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}

View File

@ -28,7 +28,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
cutlass_example_add_executable(
72a_blackwell_nvfp4_bf16_gemm
72a_blackwell_nvfp4_bf16_gemm.cu

View File

@ -28,7 +28,7 @@
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
cutlass_example_add_executable(
73_blackwell_gemm_preferred_cluster
blackwell_gemm_preferred_cluster.cu

View File

@ -513,7 +513,7 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}

View File

@ -29,9 +29,9 @@
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
74_blackwell_gemm_streamk
blackwell_gemm_streamk.cu
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
cutlass_example_add_executable(
74_blackwell_gemm_streamk
blackwell_gemm_streamk.cu
)
endif()

View File

@ -556,10 +556,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//
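Read together with the messages it prints, the branch above appears intended to accept compute capability 10.0, 10.1, or 10.3 with pre-13 toolkits, and 10.0 or 11.0 from CUDA 13 on. A sketch of that gate written out explicitly; this is an interpretation of the intent, not the committed code.
#include <cuda_runtime.h>
inline bool is_supported_device(cudaDeviceProp const& props) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 13)
  // CUDA 13+: compute capability 10.0 or 11.0, per the message printed above.
  return (props.major == 10 || props.major == 11) && props.minor == 0;
#else
  // Earlier toolkits: compute capability 10.0, 10.1, or 10.3.
  return props.major == 10 &&
         (props.minor == 0 || props.minor == 1 || props.minor == 3);
#endif
}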

View File

@ -762,9 +762,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (!(props.major == 10 && props.minor == 0)) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}

View File

@ -138,8 +138,7 @@ using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor
// Core kernel configurations
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
// Runtime Cluster Shape
@ -159,7 +158,7 @@ struct MMA2SMConfig {
};
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass,
ArchTag, OperatorClass,
typename MMA1SMConfig::MmaTileShape, ClusterShape,
Shape<_128,_64>,
ElementAccumulator, ElementAccumulator,
@ -169,7 +168,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
// , FusionOperation // Enable for SF Output
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass,
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
@ -187,7 +186,7 @@ using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass,
ArchTag, OperatorClass,
typename MMA2SMConfig::MmaTileShape, ClusterShape,
Shape<_128,_64>,
ElementAccumulator, ElementAccumulator,
@ -197,13 +196,13 @@ using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::Collective
// , FusionOperation // Enable for SF Output
>::CollectiveOp;
using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass,
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
typename MMA2SMConfig::MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
static_cast<int>(sizeof(typename CollectiveEpilogue2SM::SharedStorage))>,
typename MMA2SMConfig::KernelSchedule
>::CollectiveOp;
using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal<
@ -233,7 +232,7 @@ using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFA> layout_SFB_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
@ -897,9 +896,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (!(props.major == 10 && props.minor == 0)) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}

View File

@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if(CUTLASS_NVCC_ARCHS STREQUAL "100a")
cutlass_example_add_executable(
75_blackwell_grouped_gemm
75_blackwell_grouped_gemm.cu

View File

@ -504,10 +504,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -504,10 +504,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -500,10 +500,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -126,6 +126,7 @@ struct Options {
bool verbose = false;
bool causal = false;
bool causal_q_begin = true;
bool residual = false;
bool varlen = false;
bool persistent = false;
@ -266,6 +267,8 @@ struct Options {
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
std::string causal_type;
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
if (mask == "no" || mask == "") {
causal = residual = false;
if (varlen) {
@ -275,6 +278,11 @@ struct Options {
else if (mask == "causal") {
residual = false;
causal = true;
if(causal_type == "qend") {
causal_q_begin = false;
} else {
causal_q_begin = true;
}
}
else if (mask == "residual") {
residual = true;
@ -313,6 +321,7 @@ struct Options {
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --causal-type=<qbegin|qend> Causal mask type\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
@ -410,16 +419,16 @@ struct FwdRunner {
using ElementAccumulatorPV = float;
using ElementOut = cutlass::half_t;
// Q K D (B H)
// Q K D ((H_R, H_K) B)
using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D (H_G H_R B)
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D (H_G H_R B)
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D ((H_R, H_K), B)
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D ((H_R, H_K), B)
using StrideV = StrideK;
using StrideO = StrideQ;
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q (H_G H_R B)
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q ((H_R, H_K), B)
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
@ -602,8 +611,8 @@ struct FwdRunner {
ProblemShapeType problem_size_for_launch;
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
get<2>(problem_size_for_launch) = get<2>(problem_size);
get<3>(problem_size_for_launch) = get<3>(problem_size);
@ -660,9 +669,9 @@ struct FwdRunner {
}
auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_Q.reset(size(shape_QO));
buffer.block_K.reset(size(shape_KV));
buffer.block_V.reset(size(shape_KV));
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
@ -1078,7 +1087,11 @@ int main_single(int argc, char const **args) {
auto with_mask = [&](auto fn) {
if (options.causal) {
fn(CausalMask{});
if(options.causal_q_begin) {
fn(CausalMask{});
} else {
fn(CausalMask<false>{});
}
}
else if (options.residual) {
fn(ResidualMask{});

View File

@ -183,6 +183,9 @@ struct Options {
cmd.get_cmd_line_argument("h", h, -1);
if (h == -1) h = 2048 / d;
cmd.get_cmd_line_argument("h_k", h_k, -1);
if (h_k == -1) h_k = h;
varlen = cmd.check_cmd_line_flag("varlen");
cmd.get_cmd_line_argument("q", q, -1);
@ -298,6 +301,7 @@ struct Options {
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the B extent\n"
<< " --h=<int> Sets the H extent\n"
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
@ -405,25 +409,24 @@ struct BwdRunner {
#endif
using ElementAccumulator = float;
// Q K D (H B)
// Q K D D_VO ((H_R, H_K) B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
cute::tuple<int, int, int, int, cute::tuple<int, int>>
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
>;
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
using StrideK = TensorStride;
using StrideV = TensorStride;
using StrideO = TensorStride;
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)
// Backwards specific
using StrideDQ = TensorStride;
using StrideDK = TensorStride;
using StrideDV = TensorStride;
using StrideDO = TensorStride;
using StrideDQ = StrideQ;
using StrideDK = StrideK;
using StrideDV = StrideV;
using StrideDO = StrideO;
//
// Data members
@ -468,43 +471,15 @@ struct BwdRunner {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [H, B] = HB;
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,4>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,4>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,3,4>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,3,4>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,4>(problem_shape),
stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
select<0,2,4>(problem_shape),
stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
select<1,2,4>(problem_shape),
stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
select<1,3,4>(problem_shape),
stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
select<0,3,4>(problem_shape),
stride_dO);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), make_shape(Q, D, HB), stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), make_shape(K, D, HB), stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), make_shape(K, D_VO, HB), stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), make_shape(Q, D_VO, HB), stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), make_shape(Q, HB), stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), make_shape(Q, D, HB), stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), make_shape(K, D, HB), stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO);
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
@ -549,6 +524,9 @@ struct BwdRunner {
}
auto initialize_problem_shape(Options const& options) {
int h_r = options.h / options.h_k;
assert(options.h % options.h_k == 0);
if constexpr (kIsVarlen) {
int num_batches = options.b;
@ -599,14 +577,14 @@ struct BwdRunner {
ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, options.d_vo, {options.h, options.b}
options.d, options.d_vo, {{h_r, options.h_k}, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1));
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(make_shape(h_r, options.h_k), 1));
return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}};
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {{h_r, options.h_k}, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}
@ -616,22 +594,23 @@ struct BwdRunner {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, D_VO, HB] = tensor_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B
auto shape_Q = make_shape(Q, D, make_shape(H, B));
auto shape_O = make_shape(Q, D_VO, make_shape(H, B));
auto shape_K = make_shape(K, D, make_shape(H, B));
auto shape_V = make_shape(K, D_VO, make_shape(H, B));
auto shape_LSE = make_shape(Q, make_shape(H, B));
auto shape_Q = make_shape(Q, D, HB);
auto shape_K = make_shape(K, D, HB);
auto shape_V = make_shape(K, D_VO, HB);
auto shape_O = make_shape(Q, D_VO, HB);
auto shape_LSE = make_shape(Q, HB);
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H));
stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H));
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
stride_Q = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), B == 1 ? 0 : D*K*H_K));
stride_V = make_stride(D_VO, _1{}, make_stride(make_stride(_0{},D_VO*K), B == 1 ? 0 : D_VO*K*H_K));
stride_O = make_stride(D_VO, _1{}, make_stride(make_stride(D_VO*Q, D_VO*Q*H_R), B == 1 ? 0 : D_VO*Q*H_R*H_K));
stride_LSE = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
stride_dQ = stride_Q;
stride_dK = stride_K;
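A plain-index sketch of what the K stride above encodes; this is just the equivalent pointer arithmetic, not CuTe. The H_R mode carries stride 0, so every query head within a group reads the same K/V head, which is exactly the GQA broadcast.
#include <cstdint>
// Offset of element (k, d) of K for query head h_r within its group, KV head h_k, and
// batch b, assuming the row-major [K, D] per-head layout used above.
inline int64_t k_offset(int k, int d, int h_r, int h_k, int b,
                        int K, int D, int H_K, int B) {
  int64_t stride_k  = D;                                     // consecutive K rows are D elements apart
  int64_t stride_d  = 1;
  int64_t stride_hr = 0;                                     // broadcast across the H_R query heads
  int64_t stride_hk = int64_t(D) * K;                        // next KV head
  int64_t stride_b  = (B == 1) ? 0 : int64_t(D) * K * H_K;   // next batch (zero when B == 1, as for packed varlen)
  return k * stride_k + d * stride_d + h_r * stride_hr + h_k * stride_hk + b * stride_b;
}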
@ -642,20 +621,23 @@ struct BwdRunner {
return size(make_shape(1ull, shape));
};
auto size_K = lsize(K * D * H_K * B);
auto size_V = lsize(K * D_VO * H_K * B);
block_Q.reset(lsize(shape_Q));
block_K.reset(lsize(shape_K));
block_V.reset(lsize(shape_V));
block_K.reset(size_K);
block_V.reset(size_V);
block_O.reset(lsize(shape_O));
block_LSE.reset(lsize(shape_LSE));
block_dQ.reset(lsize(shape_Q));
block_dK.reset(lsize(shape_K));
block_dV.reset(lsize(shape_V));
block_dK.reset(size_K);
block_dV.reset(size_V);
block_dO.reset(lsize(shape_O));
block_ref_dQ.reset(lsize(shape_Q));
block_ref_dK.reset(lsize(shape_K));
block_ref_dV.reset(lsize(shape_V));
block_ref_dK.reset(size_K);
block_ref_dV.reset(size_V);
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
@ -689,7 +671,7 @@ struct BwdRunner {
select<0,4>(problem_shape),
stride_LSE);
if (! options.skip_reference) {
if (not options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}
@ -816,11 +798,12 @@ struct BwdRunner {
runtime_ms /= static_cast<float>(options.iterations);
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask<false>> || std::is_same_v<ActiveMask, CausalForBackwardMask<true>> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
flops *= static_cast<double>(get<4,0>(problem_shape));
flops *= static_cast<double>(get<4,0,0>(problem_shape));
flops *= static_cast<double>(get<4,0,1>(problem_shape));
flops *= static_cast<double>(get<4,1>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
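A worked instance of the FLOP model above, with illustrative sizes Q = K = 1024, D = 128, D_VO = 128, H_R = 2, H_K = 4, B = 1 and a full mask; the 3*D + 2*D_VO term presumably counts the three D-sized and two D_VO-sized GEMMs of the backward pass.
#include <cstdio>
int main() {
  double flops = 2.0 * 1.0;              // the 1.0 becomes 0.5 under a causal mask
  flops *= 1024.0;                       // Q
  flops *= 1024.0;                       // K
  flops *= 3.0 * 128 + 2.0 * 128;        // 3*D + 2*D_VO = 640
  flops *= 2.0 * 4.0 * 1.0;              // H_R * H_K * B
  std::printf("%.1f GFLOP per backward pass\n", flops * 1e-9);  // about 10.7 GFLOP
  return 0;
}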
@ -1001,7 +984,7 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;

View File

@ -80,6 +80,7 @@ struct Options {
int iterations = 3;
bool verify = false;
bool verbose = false;
bool is_fused_reduction = false;
int sm_count = 0;
@ -139,9 +140,12 @@ struct Options {
if (b == 0) b = 1;
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
if (split_kv == 0) {
split_kv = 1;
}
cmd.get_cmd_line_argument("page", page, defaults.page);
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
is_var_split_kv = cmd.check_cmd_line_flag("var_split_kv");
if (page == -1) {
is_var_split_kv = false;
}
@ -149,6 +153,10 @@ struct Options {
if (is_var_split_kv == true) {
split_kv = max_split_kv;
}
is_fused_reduction = cmd.check_cmd_line_flag("fuse_reduction");
if (split_kv == 1) {
is_fused_reduction = false;
}
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
@ -176,6 +184,8 @@ struct Options {
<< " --iterations=<int> Benchmarking iterations\n"
<< " --spread=<float> Relative spread away from K for paging\n"
<< " --split_kv=<int> Split KV factor\n"
<< " --fused_reduction Fuse the reduction operation\n"
<< " --var_split_kv Use varying split KV factor\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --sm-count Sets SM count rather than querying it\n"
@ -514,7 +524,8 @@ struct Runner {
stride_LSE},
hw_info,
options.split_kv,
options.is_var_split_kv ? block_split_kv.get() : nullptr
options.is_var_split_kv ? block_split_kv.get() : nullptr,
options.is_fused_reduction
};
if (options.split_kv < 0 && !options.is_var_split_kv) {
Operation::set_split_kv(arguments);
@ -724,13 +735,17 @@ void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
if (!options.is_fused_reduction || options.split_kv == 1) {
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
}
#elif FP16
name += " fp16";
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
if (!options.is_fused_reduction || options.split_kv == 1) {
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
}
#endif
}

View File

@ -90,6 +90,7 @@ struct Options {
bool verbose = false;
bool causal = false;
bool causal_q_begin = true;
bool residual = false;
bool varlen = false;
bool persistent = false;
@ -231,6 +232,8 @@ struct Options {
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
std::string causal_type;
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
if (mask == "no" || mask == "") {
causal = residual = false;
if (varlen) {
@ -240,6 +243,11 @@ struct Options {
else if (mask == "causal") {
residual = false;
causal = true;
if(causal_type == "qend") {
causal_q_begin = false;
} else {
causal_q_begin = true;
}
}
else if (mask == "residual") {
residual = true;
@ -279,6 +287,7 @@ struct Options {
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --causal-type=<qbegin|qend> Causal mask type\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
@ -581,8 +590,8 @@ struct MlaFwdRunner {
ProblemShapeType problem_size_for_launch;
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
get<2>(problem_size_for_launch) = get<2>(problem_size);
get<3>(problem_size_for_launch) = get<3>(problem_size);
@ -642,9 +651,9 @@ struct MlaFwdRunner {
}
auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_Q), kIsVarlen ? D_latent_rope*SQ*H : 0);
buffer.block_K.reset(size(shape_K), kIsVarlen ? D_latent_rope*SK*H_K : 0);
buffer.block_V.reset(size(shape_V), kIsVarlen ? D*SK*H_K : 0);
buffer.block_Q.reset(size(shape_Q));
buffer.block_K.reset(size(shape_K));
buffer.block_V.reset(size(shape_V));
buffer.block_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
buffer.block_ref_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0);
@ -840,7 +849,8 @@ struct MlaFwdRunner {
flops *= static_cast<double>(size<3,1>(problem_shape));
}
flops *= 2.0 * (std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
flops *= 2.0 * (std::is_same_v<ActiveMask, CausalMask<false>> ||
std::is_same_v<ActiveMask, CausalMask<true>> ? 0.5 : 1.0);
flops *= static_cast<double>(size<3,0>(problem_shape));
double flops0 = flops * static_cast<double>(size<2, 0>(problem_shape) + size<2, 1>(problem_shape));
@ -1013,7 +1023,11 @@ int main_single(int argc, char const **args) {
auto with_mask = [&](auto fn) {
if (options.causal) {
fn(CausalMask<false>{});
if(options.causal_q_begin) {
fn(CausalMask{});
} else {
fn(CausalMask<false>{});
}
}
else if (options.residual) {
fn(ResidualMask{});

View File

@ -59,6 +59,16 @@ set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
set(TEST_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
@ -75,6 +85,15 @@ set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --d
set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_MLA_FWD_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_MLA_FWD_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
set(TEST_MLA_FWD_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
@ -87,6 +106,9 @@ set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify)
set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
foreach(PREC fp8 fp16)
@ -116,6 +138,14 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
TEST_VARLEN_15
TEST_VARLEN_16
TEST_VARLEN_17
TEST_VARLEN_18
TEST_VARLEN_19
TEST_VARLEN_20
TEST_VARLEN_21
TEST_VARLEN_22
)
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
@ -139,6 +169,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
TEST_MLA_SEP_REDUCTION
TEST_MLA_FUSE_REDUCTION
)
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
@ -149,6 +181,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
TEST_MLA_SEP_REDUCTION
TEST_MLA_FUSE_REDUCTION
)
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
@ -207,6 +241,14 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_MLA_FWD_VARLEN_12
TEST_MLA_FWD_VARLEN_13
TEST_MLA_FWD_VARLEN_14
TEST_MLA_FWD_VARLEN_15
TEST_MLA_FWD_VARLEN_16
TEST_MLA_FWD_VARLEN_17
TEST_MLA_FWD_VARLEN_18
TEST_MLA_FWD_VARLEN_19
TEST_MLA_FWD_VARLEN_20
TEST_MLA_FWD_VARLEN_21
TEST_MLA_FWD_VARLEN_22
)
target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})

View File

@ -8,7 +8,7 @@ For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit
Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.
For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
The kernel and collective layer are then formulated to be fmha-specific.
@ -67,6 +67,8 @@ For detailed information on how to invoke them, check out either the tests in `C
to simplify the sample, clarified that `fmha_gen` sample only supports head
dim 128.
* 4.3.0: For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

View File

@ -203,13 +203,12 @@ struct CausalMask : NoMask {
// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
} else {
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
@ -222,12 +221,12 @@ struct CausalMask : NoMask {
TileShape const& tile_shape,
ProblemSize const& problem_size) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape));
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
}
}
@ -277,9 +276,10 @@ struct CausalMask : NoMask {
}
};
struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
template<bool kIsQBegin = true>
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
using Base = CausalMask<true>;
using Base = CausalMask<kIsQBegin>;
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
@ -296,10 +296,15 @@ struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
int offset_q = 0;
if constexpr (!kIsQBegin) {
offset_q = get<1>(problem_size) - get<0>(problem_size);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size);
bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
if (masked) {
acc_qk(i) = -INFINITY;
}
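A scalar sketch of the two causal conventions selected by the template parameter above, with boundary and residual handling omitted: with kIsQBegin the diagonal is anchored at the first query row, otherwise it is shifted by offset_q = K - Q so that it ends at the last query row.
#include <cstdio>
// True when score (q, k) of a Q x K score matrix is masked out (0-based indices).
bool is_masked(int q, int k, int Q, int K, bool q_begin) {
  int offset_q = q_begin ? 0 : (K - Q);  // qend: align the diagonal with the last query row
  return q + offset_q < k;
}
int main() {
  // Q = 2, K = 4: qbegin keeps k <= q, qend keeps k <= q + 2.
  for (int q = 0; q < 2; ++q) {
    for (int k = 0; k < 4; ++k) {
      std::printf("q=%d k=%d  qbegin_masked=%d  qend_masked=%d\n", q, k,
                  int(is_masked(q, k, 2, 4, true)), int(is_masked(q, k, 2, 4, false)));
    }
  }
  return 0;
}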

View File

@ -534,14 +534,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);

View File

@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
auto problem_shape_qk = problem_shape;
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dQ) = get<0>(dQ);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = problem_shape;
}
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length;
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
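A host-side sketch of the variable-length metadata the loader above consumes, under the assumption that cumulative_length[b] is the exclusive prefix sum it indexes with the batch coordinate and total_length is the packed extent placed into the problem shape.
#include <algorithm>
#include <cstddef>
#include <vector>
struct VarlenInfo {
  int max_length = 0;               // feeds VariableLength::max_length
  std::vector<int> cumulative;      // cumulative[b] = start row of batch b in the packed tensor
  int total_length = 0;             // feeds VariableLength::total_length
};
inline VarlenInfo make_varlen(std::vector<int> const& seqlens) {
  VarlenInfo v;
  v.cumulative.assign(seqlens.size() + 1, 0);
  for (std::size_t b = 0; b < seqlens.size(); ++b) {
    v.cumulative[b + 1] = v.cumulative[b] + seqlens[b];
    v.max_length = std::max(v.max_length, seqlens[b]);
  }
  v.total_length = v.cumulative.back();
  return v;
}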

View File

@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dQ) = get<0>(dQ);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
}
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length;
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
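
For variable-length MLA forward, the TMA tensors are now built over an all-int problem shape whose Q/K extents are the packed totals (total_length) rather than per-batch maxima padded by max_length. A rough sketch of that bookkeeping, with hypothetical names:

#include <vector>

// Hypothetical helper: assemble flat (Q_total, K_total, D, H, B) extents for
// descriptor creation when both Q and K are variable-length. The packed
// totals are the last entries of the cumulative-length tables, and D is the
// sum of the latent and rope head dimensions (get<2,0> + get<2,1> above).
struct FlatExtents { int q, k, d, h, b; };

FlatExtents make_varlen_extents(const std::vector<int>& cum_len_q,
                                const std::vector<int>& cum_len_k,
                                int d_latent, int d_rope, int h, int b) {
  return FlatExtents{
    cum_len_q.back(),      // total packed Q rows across the batch
    cum_len_k.back(),      // total packed K rows across the batch
    d_latent + d_rope,
    h,
    b
  };
}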

View File

@ -60,6 +60,38 @@ template<
class Mask
>
class Sm100FmhaBwd {
private:
template <typename T>
constexpr static auto to_bwd_shape(T shape) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(shape))::value;
auto HB = get<R-1>(shape);
auto rest = take<0,R-1>(shape);
return append(rest, make_shape(size<0>(HB), get<1>(HB)));
}
else {
return shape;
}
}
template <typename T>
constexpr static auto to_bwd_stride(T stride) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(stride))::value;
auto HB = get<R-1>(stride);
auto rest = take<0,R-1>(stride);
if constexpr (is_same_v<remove_cv_t<decltype(get<0,0>(HB))>, _0>) {
return append(rest, make_stride(get<0,1>(HB), get<1>(HB)));
}
else {
return append(rest, make_stride(get<0,0>(HB), get<1>(HB)));
}
}
else {
return stride;
}
}
public:
/// Argument structure: User API
struct Arguments {
@ -67,26 +99,26 @@ public:
ProblemShape problem_shape;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_Q;
const Element* ptr_K;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_K;
const Element* ptr_V;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_V;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dO;
Element* ptr_dQ;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dQ;
Element* ptr_dK;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dK;
Element* ptr_dV;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dV;
ElementAccumulator softmax_scale;
@ -106,9 +138,10 @@ public:
>
>;
using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{}));
using OperationMla = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask
>
>;
@ -134,10 +167,11 @@ private:
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_shape,
@ -154,14 +188,15 @@ private:
using namespace cute;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
return typename OperationConvert::Arguments {
args.problem_shape,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
@ -171,22 +206,22 @@ private:
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_shape,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
scaled_lse, stride_scaled_lse,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ,
to_bwd_shape(args.problem_shape),
{ args.ptr_Q, to_bwd_stride(args.stride_Q),
args.ptr_K, to_bwd_stride(args.stride_K),
args.ptr_V, to_bwd_stride(args.stride_V),
args.ptr_dO, to_bwd_stride(args.stride_dO),
scaled_lse, to_bwd_stride(stride_scaled_lse),
sum_OdO, to_bwd_stride(stride_sum_OdO),
dQ_acc, to_bwd_stride(stride_dQ),
args.softmax_scale },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
{ args.ptr_dK, to_bwd_stride(args.stride_dK),
args.ptr_dV, to_bwd_stride(args.stride_dV) },
args.hw_info
};
}
@ -220,7 +255,7 @@ public:
static size_t
get_workspace_size(Arguments const& args) {
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0;
@ -240,7 +275,7 @@ public:
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
@ -269,7 +304,7 @@ public:
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
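
The stride types above change the head mode from a single int to a nested (H_R, H_K) pair: Q/O/dQ carry a real stride for every query head, while K/V/dK/dV carry a _0 stride over H_R, so all H_R query heads of a group read from and accumulate into the same K/V head (GQA). to_bwd_stride then collapses the pair back into a single head mode for the MLA backward kernel. A plain-C++ sketch of both ideas, with hypothetical names:

#include <cstdint>

// Element offset of (row, col, (h_r, h_k), batch) given a nested stride
// ((s_hr, s_hk), s_b) for the head/batch modes. For K/V in GQA, s_hr == 0,
// so every query head h_r within group h_k maps to the same K/V slice.
int64_t element_offset(int row, int col,
                       int h_r, int h_k, int batch,
                       int64_t s_row, int64_t s_col,
                       int64_t s_hr, int64_t s_hk, int64_t s_b) {
  return row * s_row + col * s_col + h_r * s_hr + h_k * s_hk + batch * s_b;
}

// Collapsing (H_R, H_K) into one head mode, as to_bwd_stride does above:
// keep s_hk when the H_R stride is zero (shared K/V), otherwise keep s_hr.
int64_t collapsed_head_stride(int64_t s_hr, int64_t s_hk) {
  return (s_hr == 0) ? s_hk : s_hr;
}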

View File

@ -127,7 +127,11 @@ public:
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware;
if (args.is_fused_reduction && split_wave_aware > 1) {
args.split_kv = std::min(split_wave_aware, static_cast<int>(sm_count/2));
} else {
args.split_kv = split_wave_aware;
}
}
/// Determines whether the GEMM can execute the given problem.
@ -273,11 +277,33 @@ public:
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
auto [H, K, D, B] = params.fmha_params.problem_shape;
auto [D_latent, D_rope] = D;
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int);
uint8_t* ws = reinterpret_cast<uint8_t*>(params.fmha_params.epilogue.ptr_lse_exchange_buff);
result = cudaMemsetAsync(ws, 0, total_bytes, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
@ -298,7 +324,7 @@ public:
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (params.reduction_params.split_kv > 1) {
if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
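
Two things change in the MLA device layer: the wave-aware split-KV heuristic is capped at half the SM count when the fused (single-kernel) reduction is enabled, and the output plus the LSE exchange buffers are zeroed before launch because the fused path accumulates into them with atomics and TMA reduce-add. A sketch of the heuristic arithmetic, taking the precomputed split_heur as an input; the cap at sm_count / 2 presumably keeps all splits co-resident, which matches the 2 * split_kv <= sm_count check in can_implement:

#include <algorithm>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Wave-aware split selection as in the hunk above: spread max_splits over an
// integral number of "k waves", then cap the result for the fused reduction.
int choose_split_kv(int max_splits, int split_heur, int sm_count,
                    bool is_fused_reduction) {
  int k_waves          = ceil_div(max_splits, split_heur);
  int split_wave_aware = ceil_div(max_splits, k_waves);
  if (is_fused_reduction && split_wave_aware > 1) {
    return std::min(split_wave_aware, sm_count / 2);
  }
  return split_wave_aware;
}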

View File

@ -46,18 +46,18 @@ struct FmhaKernelBwdConvert {
ProblemShape problem_shape;
const ElementAcc* ptr_src_dQ;
tuple<int, _1, tuple<int, int>> stride_src_dQ;
tuple<int, _1, tuple<tuple<int, int>, int>> stride_src_dQ;
const ElementAcc* ptr_src_dK;
tuple<int, _1, tuple<int, int>> stride_src_dK;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dK;
const ElementAcc* ptr_src_dV;
tuple<int, _1, tuple<int, int>> stride_src_dV;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dV;
Element* ptr_dest_dQ;
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
tuple<int, _1, tuple<tuple<int, int>, int>> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, _1, tuple<int, int>> stride_dest_dK;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, _1, tuple<int, int>> stride_dest_dV;
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dV;
ElementAcc scale = 1.0;
};
@ -104,8 +104,8 @@ struct FmhaKernelBwdConvert {
template<class StrideSrc, class StrideDest, class Count>
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
auto ptr_src_bh = ptr_src + get<2,0,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<2,0,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
int seqlen = count;
if constexpr (is_variable_length_v<decltype(count)>) {

View File

@ -46,18 +46,18 @@ struct FmhaKernelBwdSumOdO {
ProblemShape problem_shape;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_O;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_dO;
ElementAcc* ptr_sum_OdO;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_sum_OdO;
const ElementAcc* ptr_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_lse;
ElementAcc* ptr_scaled_lse = nullptr;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_scaled_lse;
ElementAcc sum_odo_scale = 1.0;
ElementAcc lse_scale = 1.0;
@ -104,11 +104,11 @@ struct FmhaKernelBwdSumOdO {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
auto problem_q = get<0>(params.problem_shape);
int seqlen_q = problem_q;
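
FmhaKernelBwdSumOdO is the backward preprocessing pass: as the name suggests, it computes per query row the dot product of O and dO (the "delta" term of the FlashAttention backward) and a scaled copy of the LSE; the stride changes above only thread the new (H_R, H_K) head mode through it. A reference sketch of the reduction it performs, written as plain host code with hypothetical names:

#include <vector>

// Reference for the per-row O.dO reduction: delta[q] = sum_d O[q][d] * dO[q][d],
// scaled by sum_odo_scale, for one (head, batch) slice of length seqlen x d.
std::vector<float> sum_o_do(const std::vector<float>& O,
                            const std::vector<float>& dO,
                            int seqlen, int d, float sum_odo_scale) {
  std::vector<float> delta(seqlen, 0.0f);
  for (int q = 0; q < seqlen; ++q) {
    float acc = 0.0f;
    for (int k = 0; k < d; ++k) {
      acc += O[q * d + k] * dO[q * d + k];
    }
    delta[q] = sum_odo_scale * acc;
  }
  return delta;
}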

View File

@ -119,13 +119,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int kStages = 2;
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
using TensorStrideContiguousK = Stride<int, _1, Stride<Stride<int,int>, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<Stride<int,int>, int>>;
using TensorStrideContiguousK_GQA = Stride<int, _1, Stride<Stride<_0,int>, int>>;
using TensorStrideContiguousMN_GQA = Stride<_1, int, Stride<Stride<_0,int>, int>>;
// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK_GQA, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDQK>,
@ -137,7 +139,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
// compute dP
using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStrideContiguousK, Alignment,
Element, TensorStrideContiguousK_GQA, Alignment,
Element, TensorStrideContiguousK, Alignment,
ElementAcc,
Shape<TileShapeK, TileShapeQ, TileShapeDVO>,
@ -177,7 +179,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// somewhat arbitrary since we dump to smem, need to agree with the previous one
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN, Alignment,
Element, TensorStrideContiguousMN_GQA, Alignment,
ElementAcc,
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
@ -278,15 +280,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
using TensorStride = TensorStrideContiguousK; // S D (H B)
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
using TensorStride_GQA = TensorStrideContiguousK_GQA;
using RowTensorStride = Stride<_1, Stride<Stride<int, int>, int>>; // S (H B)
struct MainloopArguments {
const Element* ptr_q;
TensorStride stride_q;
const Element* ptr_k;
TensorStride stride_k;
TensorStride_GQA stride_k;
const Element* ptr_v;
TensorStride stride_v;
TensorStride_GQA stride_v;
const Element* ptr_do;
TensorStride stride_do;
@ -308,7 +311,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(make_shape(1,1), 1)), TensorStride{}),
SmemLayoutDQ{}(_, _, _0{})
));
@ -322,9 +325,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
struct EpilogueArguments {
Element* ptr_dk;
TensorStride stride_dk;
TensorStride_GQA stride_dk;
Element* ptr_dv;
TensorStride stride_dv;
TensorStride_GQA stride_dv;
};
struct Arguments {
@ -346,7 +349,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static bool can_implement(Arguments const& args) {
auto [Q, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
auto [H_R, H_K] = H;
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H_R <= 0 || H_K <= 0 || B <= 0) {
return false;
}
if (D % Alignment != 0 || D_VO % Alignment != 0) {
@ -432,7 +436,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
@ -447,6 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;
using X = Underscore;
@ -590,6 +596,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;
while (iter_count > 0) {
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
@ -660,7 +671,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
@ -1119,7 +1131,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
BlkCoord const& blk_coord,
BlkOffset const& blk_offset,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
EpilogueArguments const& epilogue_args,
@ -1141,6 +1154,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;
// in tmem, S & P overlap
// and dP and dQ overlap
@ -1201,6 +1215,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tTR_cST_p = thread_t2r.partition_D(cST);
Tensor tTR_cST = split_wg(tTR_cST_p);
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
// Tensor tTR_tST_p = thread_t2r.partition_S(tSTtST);
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
@ -1224,8 +1239,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto tRT_cST_p = thread_r2t.partition_S(tDVcST);
auto tRT_cST = split_wg(tRT_cST_p);
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
int last_iter = iter_count - 1 + iter_index;
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} > get<1>(problem_shape);
CUTLASS_PRAGMA_NO_UNROLL
while (iter_count > 0) {
@ -1245,13 +1259,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == iter_start);
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k);
}
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
@ -1372,6 +1393,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count -= 1;
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
}
}
epilogue(
@ -1384,7 +1408,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void reduce(
BlkCoord const& blk_coord,
ProblemShape_ const& problem_shape,
int iter_index,
int iter_start,
int iter_end,
int iter_count,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
@ -1397,6 +1422,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using X = Underscore;
auto [Q, K, D, D_VO, HB] = problem_shape;
int iter_index = iter_start;
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
@ -1408,7 +1434,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);
(_, _, _, _0{}, _);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1419,7 +1445,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);
@ -1465,7 +1490,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
}
@ -1474,11 +1499,18 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count -= 1;
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}
}
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
@ -1676,21 +1708,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
pipeline_init_wait(size(ClusterShape{}));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z));
auto [problem_shape, blk_offset] = apply_variable_length_offset(
params.problem_shape,
blk_coord
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
}
iter_count -= iter_start;
int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape);
if (iter_count <= 0) {
epilogue_clear(
@ -1711,6 +1745,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_offset,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.mainloop_params,
@ -1732,6 +1767,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_coord,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
shared_storage.tensors,
@ -1754,6 +1790,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_offset,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.epilogue,
@ -1785,6 +1822,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
blk_coord,
problem_shape,
iter_start,
iter_end,
iter_count,
params.mainloop,
params.mainloop_params,
@ -1801,6 +1839,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
/* no-op */
}
#endif
}
static dim3 get_block_shape() {
@ -1811,7 +1850,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
static dim3 get_grid_shape(Params const& params) {
auto [Q, K, D, D_VO, HB] = params.problem_shape;
auto [H, B] = HB;
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
auto [H_R, H_K] = H;
dim3 grid(ceil_div(K, TileShapeK{}), H_K, B);
return grid;
}
};
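
The backward kernel now derives a per-K/V-tile iteration window [iter_start, iter_end) over Q tiles instead of a single start, and handles the two causal variants separately: CausalMask<true> starts at the tile on the unshifted diagonal, while CausalMask<false> shifts the diagonal by kv_len - q_len before clamping at zero. A standalone sketch of that index arithmetic, assuming the same tile-shape conventions as the kernel:

#include <algorithm>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

// First Q tile a given K/V tile can touch under causal masking, mirroring the
// two branches above. anchored_at_q_begin == true corresponds to the
// CausalMask<true> branch (no diagonal shift); false corresponds to
// CausalMask<false>, where the diagonal is shifted by kv_len - q_len.
int causal_iter_start(int kv_tile, int tile_q, int tile_k,
                      int q_len, int kv_len, bool anchored_at_q_begin) {
  int kv_left = kv_tile * tile_k;
  if (anchored_at_q_begin) {
    return kv_left / tile_q;
  }
  int offset = kv_len - q_len;
  return std::max(0, (kv_left - offset) / tile_q);
}

// Exclusive upper bound over Q tiles; the mainloop wraps iter_index back to
// iter_start when it reaches this value so the same window can be revisited
// once per query head in the GQA group.
int causal_iter_end(int q_len, int tile_q) {
  return ceil_div(q_len, tile_q);
}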

View File

@ -1230,9 +1230,16 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
};
bool leading_causal_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
int kv_left = get<1>(blk_coord) * TileShapeK{};
int kv_right = kv_left + TileShapeK{} - 1;
int q_left = iter_index * TileShapeQ{} + offset;
int q_right = q_left + TileShapeQ{} - 1;
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
}
bool trailing_residual_masking = false;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
@ -1473,6 +1480,9 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
@ -1677,9 +1687,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
);
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
int offset = get<1>(problem_shape) - get<0>(problem_shape);
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
}
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
return;
@ -1795,6 +1807,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
/* no-op */
}
#endif
}
static dim3 get_block_shape() {

View File

@ -251,6 +251,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
TileScheduler tile_scheduler{params.tile_scheduler};
@ -465,6 +468,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else if (role == WarpRole::Correction) {
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
@ -476,6 +481,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
has_valid = true;
if (get<1>(logical_problem_shape) == 0) {
mainloop.correction_empty(
blk_coord,
@ -505,16 +512,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
if constexpr (NumWarpsEpilogue == 0) {
static_assert(NumWarpsCorrection == 1);
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
if (has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
}
else if (role == WarpRole::MMA) {
warpgroup_reg_set<NumRegsOther>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
bool allocated = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
@ -527,6 +535,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (!allocated) {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
allocated = true;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
@ -580,6 +594,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else if (role == WarpRole::Epilogue) {
warpgroup_reg_set<NumRegsOther>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
@ -591,6 +607,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
has_valid = true;
epilogue.store(
blk_coord, logical_problem_shape,
params.epilogue, params.problem_shape,
@ -602,8 +620,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
static_assert(NumWarpsEpilogue <= 1);
if constexpr (NumWarpsEpilogue == 1) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
if(has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
}
@ -612,6 +632,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
/* no-op, donate regs and exit */
}
#endif
}
};
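
In the forward kernel, TMEM is now allocated lazily by the MMA warp when it sees its first valid tile, and the correction/epilogue warps free it only if a valid tile was actually processed; previously an early-exit work assignment could free an allocation that was never made. A minimal sketch of that guard pattern with a hypothetical resource type (not the real TMEM allocator API):

#include <vector>

// Hypothetical stand-ins for the allocator and the tile scheduler above.
struct Resource { bool held = false; void acquire() { held = true; } void release() { held = false; } };
struct Tile { bool valid; };

void process(const std::vector<Tile>& tiles, Resource& tmem) {
  bool has_valid = false;
  for (const Tile& t : tiles) {
    if (!t.valid) continue;          // skipped tiles must not trigger allocation
    if (!has_valid) {                // allocate lazily on the first valid tile
      tmem.acquire();
      has_valid = true;
    }
    // ... compute on the tile ...
  }
  if (has_valid) {                   // release only what was actually allocated
    tmem.release();
  }
}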

View File

@ -247,6 +247,9 @@ struct Sm100FmhaGenKernelWarpspecialized {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
TileScheduler tile_scheduler{params.tile_scheduler};
@ -365,7 +368,7 @@ struct Sm100FmhaGenKernelWarpspecialized {
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
}
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp);
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
shared_storage.pipelines.corr_epi,
pipeline_corr_epi_params,
@ -569,6 +572,7 @@ struct Sm100FmhaGenKernelWarpspecialized {
/* no-op, donate regs and exit */
}
#endif
}
};

View File

@ -101,8 +101,17 @@ struct Sm100FmhaMlaReductionKernel {
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
if (local_split_kv == 1) return;
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
@ -113,12 +122,6 @@ struct Sm100FmhaMlaReductionKernel {
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
@ -130,17 +133,18 @@ struct Sm100FmhaMlaReductionKernel {
const int split = i * 32 + threadIdx.x;
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = max(lse_max, local_lse[i]);
lse_max = fmax(local_lse[i], lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
lse_max = fmax(__shfl_xor_sync(0xffffffff, lse_max, offset), lse_max);
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
ElementAcc sum_lse = 0;
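
The reduction kernel now hoists the split-count computation so it can exit early when the effective split is 1, and switches the LSE max reduction to fmax, which stays well defined when a lane holds -inf or NaN. A self-contained sketch of the warp-level max reduction used there:

#include <limits>

// Warp-wide maximum via XOR shuffles, mirroring the loop above. fmaxf returns
// the other operand when one input is NaN, so stray NaNs do not poison the max.
__device__ float warp_lse_max(float local_max) {
  for (int offset = 16; offset >= 1; offset /= 2) {
    local_max = fmaxf(__shfl_xor_sync(0xffffffff, local_max, offset), local_max);
  }
  // If every contribution was -inf (no valid splits), fall back to 0 so the
  // later exp(lse - lse_max) terms stay finite.
  if (local_max == -std::numeric_limits<float>::infinity()) {
    local_max = 0.0f;
  }
  return __shfl_sync(0xffffffff, local_max, 0);
}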

View File

@ -36,6 +36,7 @@
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/barrier.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
@ -44,6 +45,7 @@
#include "gather_tensor.hpp" // from examples/common
#include "common/pow_2.hpp"
#include "sm100_mla_tile_scheduler.hpp"
namespace cutlass::fmha::kernel {
@ -87,8 +89,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using TileShapeR = tuple_element_t<1, TileShapeD>;
static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim");
using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
using TensorStride = Stride<int64_t, _1, int64_t>;
using TmemAllocator = cute::conditional_t<kIs2Sm, cute::TMEM::Allocator2Sm, cute::TMEM::Allocator1Sm>;
static_assert(TileShapeH{} == 128);
@ -181,10 +183,13 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB;
using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int<IterationsPV_K>{}, _2{})));
using SmemLayoutOut = decltype(take<0,2>(typename CollectiveMmaQK::CtaShape_MNK{}));
using TileShapeAcc = decltype(take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}));
static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v<Element>);
// pre-condition for overlapped smem staging
static_assert(kBytesLoadKC == kBytesLoadVC);
static_assert(StagesQK == StagesPV);
@ -226,7 +231,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKC>> smem_kc;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutVC>> smem_vc;
};
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(2048) cute::array<ElementOut, size(TileShapeAcc{})> smem_acc;
};
};
struct SharedStorage {
@ -280,6 +288,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
KernelHardwareInfo hw_info;
int split_kv = -1;
int* ptr_split_kv = nullptr;
bool is_fused_reduction = false;
};
using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A;
@ -288,6 +297,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B;
using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B;
using GmemLayout = decltype(make_layout(Shape<int,int,int>{}, Stride<int64_t, _1, int64_t>{}));
using SmemLayout = decltype(make_layout(TileShapeAcc{}, LayoutRight{}));
using TmaReduceSum = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor(recast_ptr<ElementOut>(nullptr), GmemLayout{}), SmemLayout{}));
struct MainloopParams {
TmaLoadQLatent tma_load_q_latent;
TmaLoadQRope tma_load_q_rope;
@ -306,6 +321,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
Stride<_1, int> stride_lse;
Stride<_1, int> stride_lse_acc;
ElementAcc output_scale = 1.0f;
ElementLSE* ptr_lse_exchange_buff = nullptr;
int* ptr_lse_max_exchange_buff = nullptr;
int* ptr_lock = nullptr; // semaphore
TmaReduceSum tma_reduce_sum;
};
struct Params {
@ -316,6 +335,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
typename TileScheduler::Params tile_scheduler;
int split_kv = -1;
int* ptr_split_kv = nullptr;
bool is_fused_reduction = false;
};
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
@ -380,11 +400,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
epilogue_params.ptr_o = args.epilogue.ptr_o;
epilogue_params.stride_o = args.epilogue.stride_o;
epilogue_params.ptr_lse = args.epilogue.ptr_lse;
epilogue_params.stride_lse = args.epilogue.stride_lse;
epilogue_params.output_scale = args.epilogue.output_scale;
epilogue_params.tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(recast_ptr<ElementOut>(args.epilogue.ptr_o), make_layout(make_shape(H, L, B), args.epilogue.stride_o)), SmemLayout{});
if (args.split_kv > 1) {
if (!args.is_fused_reduction && args.split_kv > 1) {
ElementAcc* ptr_o_acc = reinterpret_cast<ElementAcc*>(workspace);
ElementLSE* ptr_lse_acc = reinterpret_cast<ElementLSE*>(ptr_o_acc + H * L * args.split_kv * B);
epilogue_params.ptr_o_acc = ptr_o_acc;
@ -392,10 +413,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
epilogue_params.stride_o_acc = make_tuple(static_cast<int64_t>(0 + L) * args.split_kv, _1{}, static_cast<int64_t>(0 + H * L) * args.split_kv);
epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv);
} else if (args.is_fused_reduction && args.split_kv > 1) {
ElementLSE* ptr_lse_exchange_buff = reinterpret_cast<ElementLSE*>(workspace);
epilogue_params.ptr_lse_exchange_buff = ptr_lse_exchange_buff;
int* ptr_lse_max_exchange_buff = reinterpret_cast<int*>(ptr_lse_exchange_buff + H * B);
epilogue_params.ptr_lse_max_exchange_buff = ptr_lse_max_exchange_buff;
int* ptr_lock = ptr_lse_max_exchange_buff + H * B;
epilogue_params.ptr_lock = ptr_lock;
}
return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params,
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv};
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv),
args.split_kv, args.ptr_split_kv, args.is_fused_reduction};
}
static size_t get_workspace_size(Arguments const& args) {
@ -403,10 +432,29 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto split_kv = args.split_kv;
return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
size_t workspace_size {0};
if (args.is_fused_reduction && args.split_kv > 1) {
// one exchange buffer for LSE max and another buffer for total LSE
// two locks per batch, first lock is for CTA0 / H=0..63 and the second is for CTA1 / H=64..127
workspace_size = H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int);
} else if (!args.is_fused_reduction && args.split_kv > 1) {
workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
}
return workspace_size;
}
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
Arguments const& args, void* ws, cudaStream_t stream) {
auto workspace_size = get_workspace_size(args);
if (args.is_fused_reduction && args.split_kv > 1) {
auto result = cudaMemsetAsync(ws, 0, workspace_size, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
@ -448,11 +496,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
std::cerr << __FILE__ << "(" << __LINE__ << "): split-k off\n";
return false;
}
if (args.is_fused_reduction && args.split_kv > 1) {
if (2 * args.split_kv > args.hw_info.sm_count ||
std::is_same_v<TileScheduler, Sm100MlaIndividualTileScheduler>) {
return false;
}
}
return true;
}
CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
TileScheduler tile_scheduler(params.tile_scheduler);
@ -746,7 +803,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_consumer_state,
pipeline_p_mma, pipeline_p_mma_producer_state,
pipeline_mma_o, pipeline_mma_o_consumer_state,
local_split_kv
local_split_kv,
params.is_fused_reduction
);
}
@ -759,6 +817,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
#endif
}
template<class BlkCoord>
@ -1777,7 +1836,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
if (epilogue_args.ptr_o_acc != nullptr) {
if (split_kv > 1) {
using ElementOutAcc = ElementAcc;
constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v<ElementOutAcc>;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc);
@ -1806,16 +1866,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
copy(tTR_rO_src, tR2G_rO_dst);
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
if (get<1>(cta_coord) == 0) {
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
}
}
}
else {
@ -1845,24 +1909,165 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
copy(tTR_rO_src, tR2G_rO_dst);
if (get<1>(cta_coord) == 0) {
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
}
}
}
}
template<class BlkCoord>
CUTLASS_DEVICE ElementLSE epilogue_lse_reduction(
ElementAcc& row_max,
ElementAcc& row_sum,
BlkCoord const& cta_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
int const& local_split_kv) {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});
constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;
auto wait = [](int* lock, int count) {
__threadfence();
if (threadIdx.x == 0) {
atomicAdd(lock, 1);
while (atomicCAS(lock, count, count) != count) {};
}
__threadfence();
Sync::sync();
};
const ElementLSE lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
Tensor mLSE_max_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE_max_buff = local_tile(mLSE_max_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
int* local_lock = epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord);
if (! kIs2Sm || (threadIdx.x < 64)) {
atomicMax(&(gLSE_max_buff(threadIdx.x)), __float2int_rn(lse));
}
wait(local_lock, local_split_kv);
auto global_lse_max = static_cast<ElementLSE>(gLSE_max_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x));
Tensor mLSE_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE_buff = local_tile(mLSE_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
if (! kIs2Sm || (threadIdx.x < 64)) {
atomicAdd(&(gLSE_buff(threadIdx.x)), expf(lse - global_lse_max));
}
wait(local_lock, 2*local_split_kv);
const auto sum_lse = gLSE_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x);
const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementLSE>::infinity() :
cutlass::fast_log(sum_lse) + global_lse_max;
const auto lse_scale = expf(lse - global_lse);
if (epilogue_args.ptr_lse != nullptr) {
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// write out the global LSE
if (! kIs2Sm || (threadIdx.x < 64)) {
gLSE(threadIdx.x) = global_lse;
}
}
return lse_scale;
}
template<class BlkCoord>
CUTLASS_DEVICE void epilogue_reduction(
ElementAcc& row_max,
ElementAcc& row_sum,
BlkCoord const& blk_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
TensorStorage& shared_tensors,
int const& local_split_kv,
ElementLSE const& lse_scale) {
constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
TiledMmaPV tiled_mma_pv;
Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{})));
CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{});
CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{});
using EpilogueLinearCombination = cutlass::epilogue::thread::LinearCombination<ElementOut, 1, ElementAcc, ElementAcc, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
EpilogueLinearCombination epilogue_op({epilogue_args.output_scale / row_sum * lse_scale});
CUTLASS_PRAGMA_UNROLL
for(int k = 0; k < IterationsPV_N; ++k) {
auto cta_coord = replace<1>(blk_coord, k);
uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + k * uint32_t(TmemAllocation::kSizeAccO);
tOtO.data() = tmem_o;
Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{});
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
Tensor gO = local_tile(mO, TileShapeAcc{}, take<0,3>(cta_coord));
auto tiled_t2r = make_tmem_copy(load_op, tAcc);
auto thread_idx = threadIdx.x % size(tiled_t2r);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_gO = thread_t2r.partition_D(gO);
Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
copy(tiled_t2r, tTR_tAcc, tTR_rAcc);
Tensor sO = make_tensor(make_smem_ptr(reinterpret_cast<ElementOut*>(shared_tensors.smem_acc.begin())), SmemLayout{});
Tensor tTR_sO = thread_t2r.partition_D(sO);
Sync::sync();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i++) {
tTR_sO(i) = epilogue_op(tTR_rAcc(i));
}
tma_store_fence();
Sync::sync();
auto tma_reduce_sum_per_cta = epilogue_args.tma_reduce_sum.get_slice(_0{});
auto gmem_tensor_coord = epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO));
auto gmem_tensor_coord_per_cta = local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0,3>(cta_coord));
if (threadIdx.x % kNumThreads == 0) {
copy(epilogue_args.tma_reduce_sum,
tma_reduce_sum_per_cta.partition_S(sO),
tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta));
tma_store_arrive();
}
tma_store_wait<0>();
}
}
template<class CtaCoord>
CUTLASS_DEVICE void compute(
@ -1877,7 +2082,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
typename PipelineP::PipelineState& pipeline_p_mma_producer_state,
PipelineO& pipeline_mma_o,
typename PipelineO::PipelineState& pipeline_mma_o_consumer_state,
int const& split_kv) {
int const& split_kv,
bool const& is_fused_reduction) {
auto [H, K, D, B] = problem_shape;
@ -1987,17 +2193,38 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();
// epilogue
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IterationsPV_N; j++) {
epilogue(
row_max, row_sum,
replace<1>(cta_coord, j), problem_shape,
mainloop_args, epilogue_args, shared_tensors,
uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv
const int actual_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
if (!is_fused_reduction || actual_split_kv == 1) {
// epilogue
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IterationsPV_N; j++) {
epilogue(
row_max, row_sum,
replace<1>(cta_coord, j), problem_shape,
mainloop_args, epilogue_args, shared_tensors,
uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO),
actual_split_kv
);
}
} else {
const ElementLSE lse_scale =
epilogue_lse_reduction(
row_max, row_sum,
cta_coord,
problem_shape,
mainloop_args, epilogue_args,
actual_split_kv
);
epilogue_reduction(row_max, row_sum,
cta_coord,
problem_shape,
mainloop_args, epilogue_args,
shared_tensors,
actual_split_kv,
lse_scale
);
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
++pipeline_mma_o_consumer_state;
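
The fused reduction replaces the separate reduction kernel: each split publishes its local LSE through global exchange buffers (an atomicMax over a rounded LSE for the running maximum, an atomicAdd of exp(lse - max) for the sum, with a spin-lock barrier between the two phases), then rescales its partial output by exp(lse - lse_global) and accumulates it with TMA reduce-add. A serial host-side sketch of the underlying log-sum-exp merge, with atomics and synchronization omitted:

#include <cmath>
#include <limits>
#include <vector>

// Merge per-split LSE values into a global LSE and per-split rescale factors:
//   lse_global = log(sum_i exp(lse_i - m)) + m,  m = max_i lse_i
//   scale_i    = exp(lse_i - lse_global)
// so that O_global = sum_i scale_i * O_i reproduces a single-pass softmax.
std::vector<float> merge_split_lse(const std::vector<float>& lse, float* lse_global_out) {
  float m = -std::numeric_limits<float>::infinity();
  for (float v : lse) m = std::fmax(m, v);
  float sum = 0.0f;
  for (float v : lse) sum += std::exp(v - m);
  float lse_global = (sum == 0.0f || std::isnan(sum))
                   ? std::numeric_limits<float>::infinity()
                   : std::log(sum) + m;
  std::vector<float> scale(lse.size());
  for (size_t i = 0; i < lse.size(); ++i) scale[i] = std::exp(lse[i] - lse_global);
  if (lse_global_out) *lse_global_out = lse_global;
  return scale;
}

Because the splits add their rescaled partial outputs directly into the final O buffer, that buffer (and the exchange buffers) must be zeroed before the kernel launches, which is what the added cudaMemsetAsync calls do.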

View File

@ -142,8 +142,8 @@ struct Sm100MlaPersistentTileScheduler {
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}
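
The scheduler fix reorders the fast-divmod decode of the flat block index so the split-KV mode is peeled off before the batch mode, matching the (m_block, split_kv, batch) layout of the launched grid. A sketch of the intended decode with plain division and modulo; the kernel uses FastDivmod, and the mode order here is an inference from the fixed code:

struct MlaBlockCoord { int m_block, bidb, n_split_kv; };

// Decode a linear block index laid out as (m_block, split_kv, batch) with
// m_block fastest-varying, mirroring the divmod order after the fix.
MlaBlockCoord decode(int block_idx, int num_m_blocks, int split_kv) {
  MlaBlockCoord c;
  c.m_block    = block_idx % num_m_blocks;  block_idx /= num_m_blocks;
  c.n_split_kv = block_idx % split_kv;      block_idx /= split_kv;
  c.bidb       = block_idx;                 // batch is the slowest mode
  return c;
}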

View File

@ -56,12 +56,12 @@ void __global__ fmha_bwd_reference_dQ_kernel(
using namespace cutlass::fmha::collective;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
@ -79,9 +79,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(
auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
ElementAcc acc_qk = 0;
ElementAcc acc_dov = 0;
ElementAcc acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
@ -94,20 +94,22 @@ void __global__ fmha_bwd_reference_dQ_kernel(
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_K] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
mS[idx_K] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_K
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
ElementAcc acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
acc += mS[idx_K] * ElementAccumulator(mK(idx_K, idx_D, idx_L));
ElementAcc rK = mK(idx_K, idx_D, idx_L);
ElementAcc rDS = mS[idx_K];
acc += rDS * rK;
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
} // for idx_D
@ -135,62 +137,83 @@ void __global__ fmha_bwd_reference_dK_kernel(
using namespace cutlass::fmha::collective;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [H, B] = get<4>(problem_shape_in);
auto [H_R, H_K] = H;
for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDK = domain_offset(select<1,2,4>(offset), mDK_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
} // for idx_D0
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;
for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
auto mDK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mDK_in);
for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
ElementAcc acc_dk = 0;
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
ElementAcc acc_dov = 0;
ElementAcc acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
acc_qk += rQ * rK;
} // for idx_D0
for (int idx_D1 = 0; idx_D1 < D_VO; idx_D1++) {
ElementAcc rDO = mDO(idx_Q, idx_D1, coord_HB);
ElementAcc rV = mV(idx_K, idx_D1, coord_HB);
ElementAcc rO = mO(idx_Q, idx_D1, coord_HB);
acc_dov += rDO * rV;
acc_doo += rDO * rO ;
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_Q
__syncthreads();
int idx_D = threadIdx.x;
if (idx_D < D) {
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
ElementAcc rQ = mQ(idx_Q, idx_D, coord_HB);
ElementAcc rDS = mS[idx_Q];
acc_dk += rDS * rQ;
}
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
__syncthreads();
} // for idx_H_R
mS[idx_Q] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_Q
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
acc += mS[idx_Q] * ElementAccumulator(mQ(idx_Q, idx_D, idx_L));
}
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
} // for idx_D
int idx_D = threadIdx.x;
if (idx_D < D) {
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
mDK(idx_K, idx_D, coord_HB) = static_cast<typename TensorDK::value_type>(acc_dk);
}
} // for idx_K
} // for idx_L
} // for idx_HB
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -216,54 +239,71 @@ void __global__ fmha_bwd_reference_dV_kernel(
using ElementAcc = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
ElementAcc* mS = reinterpret_cast<ElementAcc*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [H, B] = get<4>(problem_shape_in);
auto [H_R, H_K] = H;
for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
auto [problem_shape, offset] = apply_variable_length_offset(
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
problem_shape_in,
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
);
// problem_shape = problem_shape_in;
// offset = repeat_like(problem_shape_in, _0{});
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
auto mDV = domain_offset(select<1,3,4>(offset), mDV_in);
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
acc_qk += rQ * rK;
} // for idx_D0
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
auto mDV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mDV_in);
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
ElementAcc acc_dv = 0;
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
ElementAcc acc_qk = 0;
mS[idx_Q] = expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L));
} // for idx_Q
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
acc_qk += rQ * rK;
} // for idx_D0
__syncthreads();
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) {
ElementAcc acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
ElementAcc rS = static_cast<Element>(mS[idx_Q]);
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
acc += rS * rDO;
}
mDV(idx_K, idx_D, idx_L) = static_cast<typename TensorDV::value_type>(acc);
} // for idx_D
mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)));
} // for idx_Q
__syncthreads();
int idx_D_VO = threadIdx.x;
if (idx_D_VO < D_VO) {
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
ElementAcc rDO = mDO(idx_Q, idx_D_VO, coord_HB);
ElementAcc rP = mS[idx_Q];
acc_dv += rP * rDO;
}
} // for idx_D
__syncthreads();
} // for idx_H_R
int idx_D_VO = threadIdx.x;
if (idx_D_VO < D_VO) {
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
mDV(idx_K, idx_D_VO, coord_HB) = static_cast<typename TensorDV::value_type>(acc_dv);
}
} // for idx_K
} // for idx_L
}
@ -288,7 +328,7 @@ void fmha_bwd_reference_dQ(
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorLSE::value_type);
int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type);
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
}
@ -310,9 +350,12 @@ void fmha_bwd_reference_dK(
using namespace cute;
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
auto [K, D, HB] = mDK.shape();
auto [H, B] = HB;
auto [H_R, H_K] = H;
dim3 grid(K, H_K * B, 1);
dim3 block(std::max(D, 256));
int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type);
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
}
@ -334,9 +377,12 @@ void fmha_bwd_reference_dV(
using namespace cute;
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
auto [K, D_VO, HB] = mDV.shape();
auto [H, B] = HB;
auto [H_R, H_K] = H;
dim3 grid(K, H_K * B, 1);
dim3 block(std::max(D_VO, 256));
int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type);
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}

View File

@ -117,7 +117,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
cutlass::epilogue::NoSmemWarpSpecialized2Sm
cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm
>::CollectiveOp;
// Build the mainloop

View File

@ -88,7 +88,7 @@
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -189,7 +189,7 @@ cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
@ -283,7 +283,7 @@ struct Result
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -489,19 +489,28 @@ int run(Options &options)
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 100.
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// Must have compute capability at least 120.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
@ -509,8 +518,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
return 0;
}
@ -530,9 +539,9 @@ int main(int argc, char const **args) {
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}

View File

@ -86,7 +86,7 @@
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -217,7 +217,7 @@ cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_refer
// Matrix-wide normalization constant
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
@ -311,7 +311,7 @@ struct Result
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -536,19 +536,28 @@ int run(Options &options)
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 100.
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// Must have compute capability at least 120.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
@ -556,8 +565,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
return 0;
}
@ -577,9 +586,9 @@ int main(int argc, char const **args) {
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}

View File

@ -88,7 +88,7 @@
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -189,7 +189,7 @@ cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
@ -283,7 +283,7 @@ struct Result
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -489,19 +489,28 @@ int run(Options &options)
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 100.
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// Must have compute capability at least 120.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
@ -509,8 +518,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
return 0;
}
@ -530,9 +539,9 @@ int main(int argc, char const **args) {
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}

View File

@ -97,7 +97,7 @@ using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -263,7 +263,7 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
// NormConst is a single device-side constant value, its not per-batch or per-group
cutlass::DeviceAllocation<ElementAccumulator> norm_constant_device;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
@ -466,7 +466,7 @@ struct Result
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -861,30 +861,39 @@ int run(Options &options, bool host_problem_shapes_available = true)
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 ||
((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)
)
) {
std::cerr << "This example requires CUDA 12.8 or newer.\n";
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support.\n";
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support.\n";
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (!(props.major == 12 && props.minor == 0)) {
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120a).\n";
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120 or 121).\n";
return 0;
}
@ -901,7 +910,7 @@ int main(int argc, char const **args) {
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
allocate(options);
initialize(options);

View File

@ -46,7 +46,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
cutlass_example_add_executable(
79a_blackwell_geforce_nvfp4_bf16_gemm
79a_blackwell_geforce_nvfp4_bf16_gemm.cu

View File

@ -78,7 +78,7 @@
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -248,7 +248,7 @@ struct Result
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -507,25 +507,34 @@ int run(Options &options)
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 120.
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// Must have compute capability at least 120.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
return 0;
}
//
@ -540,9 +549,9 @@ int main(int argc, char const **args) {
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -78,7 +78,7 @@
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -183,7 +183,7 @@ cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> bloc
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_reference_SFD;
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
return cute::recast_ptr<T>(ptr);
@ -259,7 +259,7 @@ struct Result
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -531,25 +531,34 @@ int run(Options &options)
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
// and must have compute capability at least 120.
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// Must have compute capability at least 120.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
return 0;
}
//
@ -564,9 +573,9 @@ int main(int argc, char const **args) {
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -27,7 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
cutlass_example_add_executable(
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu

View File

@ -0,0 +1,104 @@
# Blockwise and Groupwise GEMM and Grouped GEMM on Blackwell
Blockwise and Groupwise GEMM and Grouped GEMM implement software scaling, performed in the accumulator type.
The examples in this directory demonstrate how to instantiate and run these kernels.
The profiler can instantiate and profile different kernel configurations for Blockwise and Groupwise GEMM
to determine the best-performing kernel for your workload.
## Introduction
Blockwise and Groupwise GEMM operations enable fine-grained numerical precision control by applying scale factors at configurable granularities. This is particularly useful for quantized neural networks where different regions of tensors may have different scaling requirements.
For a GEMM $D = \alpha A B + \beta C$, we introduce two scale factor tensors, SFA
and SFB. This leads to a GEMM $D = \alpha (\text{SFA} * A)(\text{SFB} * B) + \beta C$, where $*$ denotes elementwise scaling.
## Scale Factor Tensors
- *SFA*: Broadcasts the same scale within a block defined by _scale granularity m_ and _scale granularity k_ when scaling A.
  - Scale granularity m and scale granularity k are also referred to as _scale vector m_ and _k_, respectively.
- *SFB*: Broadcasts the same scale within a block defined by _scale granularity n_ and _scale granularity k_ when scaling B.
  - Scale granularity n and scale granularity k are also referred to as _scale vector n_ and _k_, respectively.
These can be represented in CuTe as:
- *SFA Layout*: $((\text{scale granularity M}, M / \text{scale granularity M}), (\text{scale granularity K}, K / \text{scale granularity K})) : ((0, int), (0, int))$
- *SFB Layout*: $((\text{scale granularity N}, N / \text{scale granularity N}), (\text{scale granularity K}, K / \text{scale granularity K})) : ((0, int), (0, int))$
The 0 stride in each mode ensures that every coordinate within a granularity block maps to the same scale factor element.
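As a concrete illustration, here is a minimal host-side CuTe sketch of such a layout; the sizes (M = 256, K = 512) and the 128x128 granularity are arbitrary assumptions for the example, not values prescribed by the kernels.
```
// Minimal CuTe sketch of an SFA layout (illustrative assumptions: M = 256, K = 512,
// scale granularity m = 128, scale granularity k = 128, M/N-major compact storage).
// The 0 strides broadcast one scale factor element across each granularity block.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  auto sfa_layout = make_layout(
      make_shape (make_shape (Int<128>{}, Int<2>{}),    // (scale granularity M, M / scale granularity M)
                  make_shape (Int<128>{}, Int<4>{})),   // (scale granularity K, K / scale granularity K)
      make_stride(make_stride(Int<0>{}, Int<1>{}),      // broadcast inside a block, unit step across M blocks
                  make_stride(Int<0>{}, Int<2>{})));    // K blocks strided past the two M blocks
  print(sfa_layout);  // ((_128,_2),(_128,_4)):((_0,_1),(_0,_2))
  return 0;
}
```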
## Configuration
For convenience, the Blockwise and Groupwise implementation provides
`cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>`
to deduce these layouts and manage the compact scale factor tensors.
By default, `cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>` makes
every scale factor tensor major in the M/N mode, but this can be configured. For example,
`cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, UMMA::Major::K, UMMA::Major::MN>`
denotes that SFA is major in the K dimension while SFB is major in the N dimension.
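A usage sketch follows. The `deduce_layoutSFA`/`tile_atom_to_shape_SFA` helper names and the header path mirror the existing blockwise GEMM examples and should be treated as assumptions rather than a verified API.
```
// Hypothetical usage sketch (helper names and header path assumed, modeled on the
// existing blockwise GEMM examples): deduce compact SFA/SFB layouts from the config.
#include "cutlass/detail/blockwise_scale_layout.hpp"  // assumed header location
#include <cute/tensor.hpp>

using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
    /*ScaleGranularityM=*/1, /*ScaleGranularityN=*/128, /*ScaleGranularityK=*/128>;

using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // assumed static helper
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // assumed static helper

void make_scale_layouts(int M, int N, int K, int L) {
  // Build the compact runtime layouts for a concrete (M, N, K, L) problem shape,
  // then size and fill the SFA/SFB device tensors accordingly.
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
  (void)layout_SFA; (void)layout_SFB;
}
```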
## Integration with Other Frameworks
If translating from frameworks like Torch, where SFA has shape
(M / ScaleGranularityM, K / ScaleGranularityK) and SFB has shape (K / ScaleGranularityK, N / ScaleGranularityN),
be sure to transpose SFB and B to fit the canonical CuTe layout form, which keeps K as the second mode.
The strides can be used to determine whether each tensor is MN- or K-major so that the layouts are formed correctly,
either directly or with the convenience wrappers, as sketched below.
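A minimal sketch of that stride check, assuming a Torch-style row-major SFB; the helper below is our own illustration, not a CUTLASS API.
```
// Minimal sketch (our own helper, not a CUTLASS API): decide from the strides
// whether mode 0 of a rank-2 layout is the contiguous (major) mode.
#include <cute/tensor.hpp>
#include <cstdio>

template <class Layout>
bool is_mode0_major(Layout const& layout) {
  return cute::stride<0>(layout) == 1;  // mode 0 is major when it has unit stride
}

int main() {
  using namespace cute;
  // Torch-style SFB of shape (K/sg_k, N/sg_n) = (4, 64), row-major: strides (64, 1).
  auto sfb_torch = make_layout(make_shape(4, 64), make_stride(64, 1));
  // After transposing to the canonical (N/sg_n, K/sg_k) form, mode 0 (the N mode) is major.
  auto sfb_cute  = make_layout(make_shape(64, 4), make_stride(1, 64));
  std::printf("torch SFB mode-0 major: %d\n", int(is_mode0_major(sfb_torch)));  // 0
  std::printf("cute  SFB mode-0 major: %d\n", int(is_mode0_major(sfb_cute)));   // 1
  return 0;
}
```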
## Kernel Selection and Profiling
To determine the most performant Blockwise/Groupwise GEMM or Grouped GEMM kernel for your use case, you can use the
[CUTLASS profiler](../../media/docs/cpp/profiler.md).
All Blockwise/Groupwise GEMMs and Grouped GEMMs with `f32` scaling of `e4m3` or runtime `f8` types can be built by
selecting a subset of kernels when configuring with CMake and passing:
`-DCUTLASS_LIBRARY_KERNELS="cutlass3x*f32xe4m3_*f32xe4m3*,cutlass3x*f32xf8_*f32xf8*"`.
The simplest way to use the profiler is to pass `m`, `n`, and `k` as well as your `scale_vec_size_m`,
`scale_vec_size_n`, and `scale_vec_size_k`. Passing `enable-best-kernel-for-fixed-shape` performs some autotuning
per kernel to determine the best rasterization orders, swizzles, and cluster sizes. Passing `blockwiseGemm`
or `GroupedGemm` through the operation flag determines which set of operations is profiled.
For example, this command dumps the performance of all compiled kernels that support scale
granularity m = 1, scale granularity n = 128, and scale granularity k = 128 for the problem size 8192x8192x8192:
```
cutlass_profiler --operation=blockwiseGemm \
--enable-best-kernel-for-fixed-shape \
--m=8192 --n=8192 --k=8192 \
--scale_vec_size_m=1 --scale_vec_size_n=128 --scale_vec_size_k=128 \
--verification-enabled=false
```
### Kernel Naming Convention
The names of the blockwise and groupwise kernels include the following new pattern: for each tensor/scale pair we have
`<scale_granularity_m or scale_granularity_n>x<scale_granularity_k><accumulator type>x<scaled tensor type>`. For example,
`cutlass3x_sm100_tensorop_gemm_64x128f32xe4m3_1x128f32xe4m3_f32_f16_f16_64x128x128_1x1x1_0_nnn_align16_1sm` denotes:
- A CUTLASS 3 GEMM for SM100 that uses tensor cores.
- SFA is f32 with a 64 element scale granularity m and a 128 element scale granularity k.
- The A matrix is e4m3.
- SFB is f32 with a 1 element scale granularity n and a 128 element scale granularity k.
- The B matrix is e4m3.
- The epilogue is done in f32.
- The C matrix is f16.
- The D matrix is f16.
- The MMA tile shape is 64x128x128.
- The cluster shape is 1x1x1.
- A, B, C, and D are all column major.
- The alignment of the major modes are 16 elements for A, B, C, and D.
- The MMA variant is a 1SM instruction.
It is also worthwhile to note that C can be void if scaling by beta is not needed.
## Performance Tips and Tricks
- *MMA Dimensions*: in both Blackwell and Hopper tensor cores the smallest `MMA_M` dimension is 64, but the `MMA_N`
dimension can be as small as 8 for some instructions. For problem sizes where M is small, consider computing $D^T = \alpha B^T A^T + \beta C^T$ instead.
  - After swapping A and B and transposing, the small dimension becomes N. With a small `MMA_N` we can tile more effectively without performing unnecessary computation.
- *Layout Swapping*: When optimizing with the profiler, swap the `m` and `n` inputs and adjust the layouts to reflect this swapping and transposing.
  - For example, if we have a row-major A, column-major B, and row-major D, we can swap tensors and run a kernel with:
    - the left-hand matrix as row-major (since B transposed is row-major),
    - the right-hand matrix as column-major (since A transposed is column-major),
    - a column-major output (since D transposed is column-major).
When using blockwise and groupwise GEMM, we must also swap the scale vector sizes when applying this optimization. If we have a 1-element scale granularity M
and a 128-element scale granularity N, we must run a kernel with a 128-element scale granularity M and a 1-element scale granularity N, as in the profiler command sketched below.
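For instance, suppose the original problem is m=16, n=8192, k=8192 with scale_vec_size_m=1, scale_vec_size_n=128, and scale_vec_size_k=128 (illustrative numbers, using the same profiler flags shown earlier). The swapped and transposed equivalent exchanges m with n and exchanges the m/n scale vector sizes:
```
# Original problem (illustrative): m=16, n=8192, k=8192 with
# scale_vec_size_m=1, scale_vec_size_n=128, scale_vec_size_k=128.
# Swapped/transposed equivalent: exchange m <-> n and the m/n scale vector sizes.
cutlass_profiler --operation=blockwiseGemm \
                 --enable-best-kernel-for-fixed-shape \
                 --m=8192 --n=16 --k=8192 \
                 --scale_vec_size_m=128 --scale_vec_size_n=1 --scale_vec_size_k=128 \
                 --verification-enabled=false
```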

View File

@ -26,7 +26,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
82_blackwell_distributed_gemm
82_blackwell_distributed_gemm.cu
)
endif()

View File

@ -0,0 +1,497 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/detail/collective/mixed_input_utils.hpp"
#include "helper.h"
#include "mixed_dtype_helper.cuh"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::int4b_t;
using AccumulatorType = float;
// A matrix configuration
using ElementA = MmaType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = QuantType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using ElementZero = MmaType;
using ElementScale = MmaType;
// C/D matrix configuration
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for C and D matrix operands
using LayoutD = cutlass::layout::RowMajor;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Core kernel configurations
using ElementAccumulator = AccumulatorType; // Element type for internal accumulation
using ElementCompute = AccumulatorType; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using MmaTileShape = Shape<_256,_128,_128>; // (MmaTileShape_N, MmaTileShape_M, MmaTileShape_K) as A and B will be swapped
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmMixedInputSm100; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
constexpr int ScaleGranularityN = 1; //Should be less than or equal to GEMM_N
constexpr int ScaleGranularityK = 128; //Should be less than or equal to GEMM_K
using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig<ScaleGranularityN, ScaleGranularityK>;
using LayoutScale = decltype(ScaleConfig::deduce_layout_scale()); // Layout type for SFA matrix operand
LayoutScale layout_S;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaTileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementCompute,
// Transpose layout of D here since we use explicit swap + transpose
// the void type for C tells the builder to allocate 0 smem for the C matrix.
// We can enable this if beta == 0 by changing ElementC to void below.
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
>::CollectiveOp;
// ============================================================ MIXED INPUT NO SCALES ============================================================================
//The collective will infer that the narrow type should be upcasted to the wide type.
//We swap A and B operands to the builder here
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB>, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
MainloopSchedule
>::CollectiveOp;
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopConvertOnly,
CollectiveEpilogue
>;
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale>, cute::tuple<LayoutB_Transpose, LayoutScale>, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
MainloopSchedule
>::CollectiveOp;
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ==================================================================
// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format.
using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, ElementScale, ElementZero>, cute::tuple<LayoutB_Transpose, LayoutScale>, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
MainloopSchedule
>::CollectiveOp;
using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleWithZeroPoint,
CollectiveEpilogue
>;
using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleWithZeroPoint>;
// =================================================================================================================================================================
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = typename GemmKernelScaleOnly::StrideC;
using StrideD = typename GemmKernelScaleOnly::StrideD;
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideC_ref stride_C_ref;
StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;
// Scale and Zero share a stride since the layout and shapes must be the same.
using StrideS = typename cute::Stride<cute::Int<1>, int64_t, int64_t>;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<MmaType> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(MixedDtypeOptions const& options) {
auto shape_b = cute::make_shape(options.n, options.k, options.l);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
// Reverse stride here due to swap and transpose
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
layout_S = ScaleConfig::tile_atom_to_shape_scale(make_shape(options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_S)));
block_A.reset(a_coord.product());
block_B.reset(b_coord.product());
block_B_dq.reset(b_coord.product());
block_C.reset(c_coord.product());
block_D.reset(c_coord.product());
block_ref_D.reset(c_coord.product());
block_scale.reset(blockscale_b_coord.product());
block_zero.reset(blockscale_b_coord.product());
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale<QuantType, ElementScale>(block_scale, options);
initialize_zero(block_zero, options);
if(options.verify){
auto layout_B = make_layout(shape_b, stride_B);
auto scale_stride = layout_S.stride();
auto layout_scale_zero = make_layout(
make_shape(size<0>(layout_S), size<1,1>(layout_S), size<2>(layout_S)),
make_stride(size<0,1>(scale_stride), size<1,1>(scale_stride), size<2>(scale_stride))
); //layout = (options.n, scale_k, options.l) : (_1, options.n, _0)
cudaStream_t stream = cudaStreamDefault;
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, ScaleGranularityK, stream);
}
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <class Args, cutlass::detail::ConversionMode KernelConversionMode>
Args args_from_options(MixedDtypeOptions const& options)
{
// Swap the A and B tensors, as well as problem shapes here.
if constexpr (KernelConversionMode == cutlass::detail::ConversionMode::DirectConvert) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B.get(), stride_B, block_A.get(), stride_A},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
else if constexpr(KernelConversionMode == cutlass::detail::ConversionMode::ConvertAndScale) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), layout_S},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
else if constexpr(KernelConversionMode == cutlass::detail::ConversionMode::ConvertAndScaleWithZero) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), layout_S, block_zero.get()},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
} else {
exit(-1);
}
}
bool verify(MixedDtypeOptions const& options) {
//
// Compute reference output
//
constexpr int AlignmentBdq = 128 / cutlass::sizeof_bits<MmaType>::value;
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentBdq,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, cutlass::arch::OpClassTensorOp,
MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
EpilogueSchedule
>::CollectiveOp;
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
};
// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
// compare_reference
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-2f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(MixedDtypeOptions &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<typename Gemm::Arguments, Gemm::CollectiveMainloop::KernelConversionMode>(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
MixedDtypeResult result;
if(options.verify){
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}
else{
result.passed = true;
std::cout << " Verification: Off " << std::endl;
}
if (!result.passed) {
exit(-1);
}
mixed_dtype_profiling(gemm, options, result);
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and the GPU must have compute capability 100a.
bool is_correct_cuda_version = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
if (!is_correct_cuda_version) {
std::cerr << "CUDA Toolkit version is " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__ << "\n";
std::cerr << "This example requires CUDA 12.8 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
<< "(compute capability 100a).\n";
return 0;
}
//
// Parse options
//
MixedDtypeOptions options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
std::cout << "Running in conversion only mode." << std::endl;
run<GemmConvertOnly>(options);
}
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
std::cout << "Running in scale mode." << std::endl;
run<GemmScaleOnly>(options);
}
else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
std::cout << "Running in scale and zero mode." << std::endl;
run<GemmScaleWithZeroPoint>(options);
}
else{
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
}
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,45 @@
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
set(TEST_S_TILE_SHAPE --m=256 --n=128 --k=32 --verify --iterations=0)
set(TEST_S_TILE_SHAPE_MULTIPLE_KITER --m=256 --n=128 --k=128 --verify --iterations=0)
set(TEST_S_DIFFERENT_MN --m=16384 --n=4608 --k=4608 --verify --iterations=0)
set(TEST_S_ONE_WAVE --m=1536 --n=1536 --k=32 --verify --iterations=0) # Assuming 144 SMs
set(TEST_S_2048 --m=2048 --n=2048 --k=2048 --verify --iterations=0) # Multi-wave
if(NOT WIN32)
cutlass_example_add_executable(
86_blackwell_mixed_dtype_gemm
86_blackwell_mixed_dtype.cu
TEST_COMMAND_OPTIONS
TEST_S_TILE_SHAPE
TEST_S_TILE_SHAPE_MULTIPLE_KITER
TEST_S_ONE_WAVE
TEST_S_2048
)
endif()

View File

@ -0,0 +1,269 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cute/tensor.hpp"
#include <cuda.h>
#include <numeric>
#include "helper.h"
enum MixedDtypeGemmMode {
ConvertOnly,
ScaleOnly,
ScaleWithZeroPoint
};
/// Command line options parsing
struct MixedDtypeOptions {
bool help = false;
bool verify = false;
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 1000;
int warmup = 1000;
int mode = 1;
int m = 5120, n = 4096, k = 4096;
int l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("verify")) {
verify = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("mode", mode);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("warmup", warmup);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "86_blackwell_mixed_dtype_gemm\n\n"
<< " Blackwell Mixed Data Type GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --warmup=<int> Number of warmup iterations to perform.\n\n"
<< " --verify=<int> Run verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "86_blackwell_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
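// Worked example (for illustration only): with the default problem size m=5120, n=4096,
// k=4096, l=1, flop = 2 * 5120 * 4096 * 4096 = 171,798,691,840 (~171.8 GFLOP), so an
// average runtime of 1 ms is reported as roughly 171,798.7 GFLOPS.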
};
/// Result structure
struct MixedDtypeResult
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t error = cudaSuccess;
bool passed = false;
};
/// Profiling Loop
template <class Gemm>
void mixed_dtype_profiling(
Gemm& gemm,
MixedDtypeOptions const& options,
MixedDtypeResult& result) {
if (options.iterations <= 0) return;
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
std::vector<float> runtimes;
runtimes.reserve(options.iterations);
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
cudaEventRecord(start);
CUTLASS_CHECK(gemm.run());
cudaEventRecord(stop);
cudaEventSynchronize(stop);
if (iter >= options.warmup) {
float milliseconds = 0;
cudaEventElapsedTime(&milliseconds, start, stop);
runtimes.push_back(milliseconds);
}
}
cudaEventDestroy(start);
cudaEventDestroy(stop);
// Compute average runtime and GFLOPs.
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
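// A minimal extension sketch (not part of the original example): if the best-case latency
// is also of interest, the recorded runtimes can be reduced before they go out of scope,
// e.g.
//
//   if (!runtimes.empty()) {
//     float best_ms = *std::min_element(runtimes.begin(), runtimes.end());
//     std::cout << "  Best runtime: " << best_ms << " ms" << std::endl;
//   }
//
// (this assumes <algorithm> is included alongside <numeric>).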
/// Helpers to initialize a block of device data
template <class Element>
bool initialize_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed = 2023) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_input == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <typename Element>
bool initialize_quant_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed = 2023) {
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <class QuantType, class Element>
bool initialize_scale(
cutlass::DeviceAllocation<Element>& block,
MixedDtypeOptions const& options,
uint64_t seed = 2023) {
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
std::vector<Element> stage(block.size(), Element(1.0f));
block.copy_from_host(stage.data());
}
else {
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
const float max_dequant_val = 4.f;
const float min_dequant_val = 0.5f;
float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
}
return true;
}
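// Worked example (illustrative): for a signed 4-bit quant type, elt_max_f = 7, so scales
// are drawn from roughly [0.5/7, 4/7] ~= [0.071, 0.571]. A quantized value of magnitude up
// to 7 multiplied by such a scale then lands in about [0.5, 4.0], matching
// min_dequant_val and max_dequant_val above.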
template <class Element>
bool initialize_zero(
cutlass::DeviceAllocation<Element>& block,
MixedDtypeOptions const& options,
uint64_t seed = 2023) {
if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
} else {
// No zero-point, so just initialize with 0 so we can use the same kernel to dequantize the data.
std::vector<Element> stage(block.size(), Element(0.0f));
block.copy_from_host(stage.data());
}
return true;
}

View File

@ -0,0 +1,518 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS.
This example demonstrates a simple way to instantiate and run a blockwise scaling FP8 GEMM on the NVIDIA Blackwell SM120 architecture.
This kernel is optimized for the GeForce RTX 50 series GPUs.
This kernel accepts inputs A and B with TileMxTileK and TileNxTileK FP32 block scaling, respectively, performing scaling and accumulation every TileK elements.
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
3. Runtime scaling block size.
Usage:
$ ./examples/87_blackwell_geforce_gemm_blockwise/87a_blackwell_geforce_fp8_bf16_gemm_blockwise --m=2048 --n=2048 --k=2048
*/
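/*
  A host-side sketch (illustrative only, not the kernel's implementation) of what the
  blockwise scaling described above means for a single output element D[m][n], assuming
  TileM = TileN = TileK = 128 and K a multiple of 128:

    float acc = 0.f;
    for (int kb = 0; kb < K / 128; ++kb) {
      float partial = 0.f;
      for (int kk = 0; kk < 128; ++kk) {
        partial += float(A[m][kb * 128 + kk]) * float(B[n][kb * 128 + kk]);
      }
      acc += SFA[m / 128][kb] * SFB[n / 128][kb] * partial;   // scale every TileK elements
    }
    D[m][n] = ElementD(alpha * acc + beta * float(C[m][n]));

  where SFA and SFB hold one FP32 scale factor per 128x128 tile of A and B, respectively,
  and B is indexed in its (N, K) orientation for readability.
*/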
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
#include "./utils.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile
using MmaTileShape_MNK = Shape<_128,_128,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;
using ScaleConfig = decltype(cutlass::detail::sm120_trivial_blockwise_scale_config(MmaTileShape_MNK{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutC, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
// Strides just iterate over scalars and have no zeros
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
// Layouts are tiled to the problem size and the strides have zeros
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "87a_blackwell_geforce_gemm_blockwise\n\n"
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "87a_blackwell_geforce_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
tensor_SFA.resize(blockscale_a_coord);
tensor_SFB.resize(blockscale_b_coord);
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
initialize_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options) {
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A,
tensor_B.device_data(), stride_B,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(Options const& options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <class Gemm>
int run(Options &options) {
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// Must have compute capability 120 or 121.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,539 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS.
This example demonstrates a simple way to instantiate and run cooperative and ping-pong groupwise scaling FP8 GEMMs on the NVIDIA Blackwell SM120 architecture.
These kernels are optimized for GeForce RTX 50 series GPUs.
The blockscaling kernels accept inputs A and B with 1xTileK and TileNxTileK FP32 block scaling, respectively, performing scaling and accumulation every TileK elements.
The ping-pong kernel leverages a smaller tile shape to avoid register spilling for better performance.
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
3. Epilogue Optimization
Note that GeForce RTX 50 series GPUs do not support:
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
2. Dynamic datatypes.
3. Runtime scaling block size.
Usage:
$ ./examples/87_blackwell_geforce_gemm_blockwise/87b_blackwell_geforce_fp8_bf16_gemm_groupwise --m=2048 --n=2048 --k=2048
*/
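/*
  Illustrative note (not a statement of the kernel API): "groupwise" refers to the finer
  scaling granularity of operand A. With ScaleGranularityM = 1 and ScaleGranularityK = 128
  (see ScaleConfig below), A carries one FP32 scale per row per 128-wide K slice, so SFA
  holds roughly M x ceil(K/128) entries per batch, while B keeps one scale per 128x128
  block, so SFB holds roughly ceil(N/128) x ceil(K/128) entries per batch. The accumulation
  pattern otherwise matches 87a: each partial sum over a 128-wide K slice is multiplied by
  the corresponding SFA and SFB entries before being added to the accumulator.
*/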
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
#include "./utils.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile
using CooperativeMmaTileShape_MNK = Shape<_128,_128,_128>;
// Smaller tile size for pingpong schedule to avoid register spilling
using PingpongMmaTileShape_MNK = Shape<_64, _128, _128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
template <class TileShape>
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutC, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
template <class TileShape, class Schedule>
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue<TileShape>::SharedStorage))>,
Schedule // cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120
>::CollectiveOp;
template <class TileShape, class Schedule>
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop<TileShape, Schedule>,
CollectiveEpilogue<TileShape>,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
// We are using cooperative kernel schedule by default
using CooperativeGemm = cutlass::gemm::device::GemmUniversalAdapter<
GemmKernel<CooperativeMmaTileShape_MNK, cutlass::gemm::KernelScheduleSm120Blockwise>>;
// Pingpong kernel
using PingpongGemm = cutlass::gemm::device::GemmUniversalAdapter<
GemmKernel<PingpongMmaTileShape_MNK, cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120>>;
using StrideA = typename CooperativeGemm::GemmKernel::StrideA;
using StrideB = typename CooperativeGemm::GemmKernel::StrideB;
using StrideC = typename CooperativeGemm::GemmKernel::StrideC;
using StrideD = typename CooperativeGemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
// Strides just iterate over scalars and have no zeros
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
// Layouts are tiled to the problem size and the strides have zeros
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "87b_blackwell_geforce_gemm_groupwise\n\n"
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "87b_blackwell_geforce_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
tensor_SFA.resize(blockscale_a_coord);
tensor_SFB.resize(blockscale_b_coord);
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
initialize_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <class Gemm>
typename Gemm::Arguments args_from_options(const Options &options) {
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A,
tensor_B.device_data(), stride_B,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <class Gemm>
int run(Options &options) {
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<Gemm>(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// Must have compute capability 120 or 121.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
printf("Running kernel with Cooperative kernel schedule:\n");
run<CooperativeGemm>(options);
printf("Running kernel with Pingpong kernel schedule:\n");
run<PingpongGemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,678 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief An FP8 groupwise scaled grouped GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS.
This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM120 TensorOp-based warp-specialized kernel
for FP8 with per-group 1x128x128 FP32 scaling factors.
In this example, M, N, and K are fixed across groups.
As RTX 50 series GPUs do not support runtime-configurable scaling block sizes, all groups share the same block scaling size.
All scheduling work is performed on the device, using device-side modification of TMA descriptors
to move between groups (the problem count, represented by groups).
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device
To run this example:
$ ./examples/87_blackwell_geforce_gemm_blockwise/87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise --m=2048 --n=2048 --k=2048 --groups=10
The command above sizes all 10 groups at the given m, n, and k values.
The same applies to the alpha and beta values, which can be randomized across the different groups.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
#include "./utils.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile
using MmaTileShape_MNK = Shape<_128,_128,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;
// Scaling Factors
using ElementSF = ElementAccumulator;
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
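// As a rough sketch of what this granularity means: each FP32 scale in SFA covers a 1x128 tile of A
// (roughly M * ceil(K/128) scales per group), and each scale in SFB covers a 128x128 tile of B
// (roughly ceil(N/128) * ceil(K/128) scales per group). For the --m=2048 --n=2048 --k=2048 command
// in the file comment, that is about 2048 * 16 = 32768 scales for A and 16 * 16 = 256 scales for B per group.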
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC *, AlignmentC,
ElementD, LayoutD *, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelScheduleSm120Blockwise
>::CollectiveOp;
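// Note on the builder arguments above: the pointer forms (LayoutA *, LayoutSFA *, LayoutC *, ...) signal a
// grouped kernel that reads per-group strides/layouts from device arrays rather than a single host-side value,
// and StageCountAutoCarveout<> asks the builder to pick the mainloop stage count automatically after reserving
// the epilogue's SharedStorage bytes of shared memory. This is the builder convention the grouped examples
// appear to rely on; consult the collective builder headers for the authoritative rules.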
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA, LayoutSFA>);
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB, LayoutSFB>);
/// Initialization
uint64_t seed;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
using HostTensorA = cutlass::HostTensor<ElementA, cutlass::layout::PackedVectorLayout>;
using HostTensorB = cutlass::HostTensor<ElementB, cutlass::layout::PackedVectorLayout>;
using HostTensorC = cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout>;
using HostTensorD = cutlass::HostTensor<Gemm::EpilogueOutputOp::ElementOutput, cutlass::layout::PackedVectorLayout>;
using HostTensorSFA = cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout>;
using HostTensorSFB = cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout>;
std::vector<HostTensorA> block_A;
std::vector<HostTensorB> block_B;
std::vector<HostTensorC> block_C;
std::vector<HostTensorD> block_D;
std::vector<HostTensorD> block_ref_D;
std::vector<HostTensorSFA> block_SFA;
std::vector<HostTensorSFB> block_SFB;
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
cutlass::DeviceAllocation<ElementA const*> ptr_A;
cutlass::DeviceAllocation<ElementB const*> ptr_B;
cutlass::DeviceAllocation<ElementSF const*> ptr_SFA;
cutlass::DeviceAllocation<ElementSF const*> ptr_SFB;
cutlass::DeviceAllocation<ElementC const*> ptr_C;
cutlass::DeviceAllocation<ElementD *> ptr_D;
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
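// Each DeviceAllocation above holds one entry per group: arrays of pointers to the per-group A/B/SFA/SFB/C/D
// tensors, arrays of per-group strides and scale-factor layouts, and per-group alpha/beta values. The kernel
// indexes these arrays by group, which is how a single launch serves all of the grouped problems.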
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1, groups = 10;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
RasterOrderOptions raster_order = RasterOrderOptions::AlongN;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char, 'N');
if (raster_char == 'N' || raster_char == 'n') {
raster_order = RasterOrderOptions::AlongN;
} else if (raster_char == 'M' || raster_char == 'm') {
raster_order = RasterOrderOptions::AlongM;
}
for (int i = 0; i < groups; ++i) {
problem_sizes_host.push_back({m, n, k});
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "87c_blackwell_geforce_grouped_gemm_groupwise\n\n"
<< " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "87c_blackwell_geforce_grouped_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --groups=8 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * groups;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementSF *> ptr_SFA_host(options.groups);
std::vector<ElementSF *> ptr_SFB_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementD *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
for (int i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto [M, N, K] = problem;
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
stride_A_host.push_back(stride_A);
stride_B_host.push_back(stride_B);
layout_SFA_host.push_back(layout_SFA);
layout_SFB_host.push_back(layout_SFB);
stride_C_host.push_back(stride_C);
stride_D_host.push_back(stride_D);
block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A))));
block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B))));
block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C))));
block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
block_SFA.push_back(HostTensorSFA(cutlass::make_Coord(size(filter_zeros(layout_SFA)))));
block_SFB.push_back(HostTensorSFB(cutlass::make_Coord(size(filter_zeros(layout_SFB)))));
block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
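// The filter_zeros() calls above drop the broadcast (stride-0) modes of the interleaved scale-factor layouts,
// so size() counts only the scale factors that are actually stored (each scale is broadcast over its
// 1x128x128 block), which keeps the SFA/SFB allocations compact.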
}
for (int i = 0; i < options.groups; ++i) {
initialize_tensor(block_A.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(block_B.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(block_C.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_tensor(block_SFA.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2025);
initialize_tensor(block_SFB.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2026);
block_A.at(i).sync_device();
block_B.at(i).sync_device();
block_C.at(i).sync_device();
block_SFA.at(i).sync_device();
block_SFB.at(i).sync_device();
ptr_A_host.at(i) = block_A.at(i).device_data();
ptr_B_host.at(i) = block_B.at(i).device_data();
ptr_C_host.at(i) = block_C.at(i).device_data();
ptr_D_host.at(i) = block_D.at(i).device_data();
ptr_SFA_host.at(i) = block_SFA.at(i).device_data();
ptr_SFB_host.at(i) = block_SFB.at(i).device_data();
alpha_host.push_back((options.alpha == std::numeric_limits<float>::max()) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == std::numeric_limits<float>::max()) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_SFA.reset(options.groups);
ptr_SFA.copy_from_host(ptr_SFA_host.data());
ptr_SFB.reset(options.groups);
ptr_SFB.copy_from_host(ptr_SFB_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
layout_SFA.reset(options.groups);
layout_SFA.copy_from_host(layout_SFA_host.data());
layout_SFB.reset(options.groups);
layout_SFB.copy_from_host(layout_SFB_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = options.raster_order;
typename Gemm::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
if (options.alpha != std::numeric_limits<float>::max()) {
fusion_args.alpha = options.alpha;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.dAlpha = {_0{}, _0{}, 0};
} else {
fusion_args.alpha = 0;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.dAlpha = {_0{}, _0{}, 1};
}
if (options.beta != std::numeric_limits<float>::max()) {
fusion_args.beta = options.beta;
fusion_args.beta_ptr_array = nullptr;
fusion_args.dBeta = {_0{}, _0{}, 0};
} else {
fusion_args.beta = 0;
fusion_args.beta_ptr_array = beta_device.get();
fusion_args.dBeta = {_0{}, _0{}, 1};
}
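// The dAlpha/dBeta tuples above are broadcast strides for the epilogue scalars: {_0{}, _0{}, 0} broadcasts a
// single host-provided alpha/beta to every group, while a group stride of 1 together with
// alpha_ptr_array/beta_ptr_array makes the epilogue read a distinct per-group value from the device array.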
arguments = {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(),
ptr_B.get(), stride_B.get(),
ptr_SFA.get(), layout_SFA.get(),
ptr_SFB.get(), layout_SFB.get()},
{
fusion_args,
ptr_C.get(), stride_C.get(),
ptr_D.get(), stride_D.get()
},
hw_info, scheduler
};
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
bool passed = true;
for (int i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto [M, N, K] = problem;
auto A = cute::make_tensor(block_A.at(i).host_data(),
cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i)));
auto B = cute::make_tensor(block_B.at(i).host_data(),
cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i)));
auto C = cute::make_tensor(block_C.at(i).host_data(),
cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i)));
auto D = cute::make_tensor(block_ref_D.at(i).host_data(),
cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i)));
auto SFA = cute::make_tensor(block_SFA.at(i).host_data(), layout_SFA_host.at(i));
auto SFB = cute::make_tensor(block_SFB.at(i).host_data(), layout_SFB_host.at(i));
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = alpha_host.at(i);
epilogue_params.beta = beta_host.at(i);
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
block_D.at(i).sync_host();
passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view());
}
return passed;
}
/// Execute a given example GEMM computation
template <class Gemm>
int run(Options &options) {
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << " " << options.groups << " Groups" << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
// or CUDA 12.9 or higher for SM121 support.
// The GPU must have compute capability 120 or 121.
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
// Return zero so this test passes on older Toolkits; its actions are a no-op.
return 0;
}
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
// Return zero so this test passes on older Toolkits; its actions are a no-op.
return 0;
}
#endif
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
std::cerr << "This example requires a GPU with compute capability 120a or 121a." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
initialize(options);
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,47 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
cutlass_example_add_executable(
87a_blackwell_geforce_fp8_bf16_gemm_blockwise
87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu
)
cutlass_example_add_executable(
87b_blackwell_geforce_fp8_bf16_gemm_groupwise
87b_blackwell_geforce_fp8_bf16_gemm_groupwise.cu
)
cutlass_example_add_executable(
87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise
87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise.cu
)
endif()

View File

@ -0,0 +1,83 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/// Helper to initialize a block of device data
template <class Element, class Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_input == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
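// Typical usage from the grouped GEMM example above (Uniform distribution branch):
//   initialize_tensor(block_A.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2022);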

View File

@ -137,6 +137,9 @@ struct FmhaKernelTma {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
TileScheduler tile_scheduler{params.tile_scheduler};
// Shared memory.
@ -216,6 +219,7 @@ struct FmhaKernelTma {
result, typename CollectiveMainloop::TiledMmaPV{},
params.problem_size, params.epilogue,
epi_load_pipeline, storage.epilogue);
#endif
}
};

View File

@ -160,6 +160,9 @@ struct FmhaKernelTmaWarpSpecialized {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED)
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
@ -412,6 +415,7 @@ struct FmhaKernelTmaWarpSpecialized {
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.arrive(); }
}
}
#endif
}
};

View File

@ -0,0 +1,545 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM103 architecture.
This example demonstrates a simple way to instantiate and run a blockscaled 3xFP4 GEMM on the NVIDIA Blackwell SM103 architecture.
Usage:
$ ./examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm --m=2048 --n=2048 --k=2048
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <iostream>
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e2m1_t; // Element type for A matrix operand
using ElementSFA = cutlass::float_ue4m3_t;
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e2m1_t; // Element type for B matrix operand
using ElementSFB = cutlass::float_ue4m3_t;
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm103; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
// Kernel Perf config
using MmaTileShape = cute::Shape<cute::_128, cute::_128, Int<768>>; // MMA's tile size
using ClusterShape = cute::Shape<cute::_2, cute::_4, cute::_1>; // Shape of the threadblocks in a cluster
// Epilogue fusion operator
using EpilogueFusionOp = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
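// LinearCombination computes D = alpha * accumulator + beta * C elementwise, with alpha and beta taken from
// the epilogue arguments (options.alpha / options.beta in this example).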
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized1Sm,
EpilogueFusionOp
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementA,ElementSFA>, LayoutATag, AlignmentA,
cute::tuple<ElementB,ElementSFB>, LayoutBTag, AlignmentB,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 // Kernel schedule policy. Auto or using targeted scheduling policy
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
uint64_t seed;
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
cutlass::HostTensor<ElementA, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementSFA, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementSFB, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
// Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
// Reference Output Tensor
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
template <typename T>
auto make_iterator(T* ptr) {
return cute::recast_ptr<T>(ptr);
}
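// make_iterator wraps a raw pointer via cute::recast_ptr so that sub-byte element types (such as the 4-bit
// FP4 operands used here) are addressed through a sub-byte iterator; for full-byte types it is effectively
// a pass-through.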
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k;
int swizzle = 0;
Options():
help(false),
m(1024), n(1024), k(1024),
alpha(1.f), beta(0.f),
iterations(10),
swizzle(0)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("swizzle", swizzle);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "89_sm103_fp4_ultra_gemm\n\n"
<< " Sm103 3xFP4 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --swizzle=<int> Cluster rasterization swizzle\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ " << "./examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
}
// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, 1},
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{options.alpha, options.beta},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
}
};
arguments.scheduler.max_swizzle_size = options.swizzle;
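// max_swizzle_size passes the --swizzle option ("Cluster rasterization swizzle") to the tile scheduler;
// larger values swizzle the cluster rasterization order, which is intended to improve L2 reuse, while the
// default of 0 requests no extra swizzling.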
return arguments;
}
bool verify(const Options &options) {
using namespace cute;
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementAccumulator, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Comparison
block_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
uint8_t* workspace = nullptr;
cudaError_t status = cudaMalloc(&workspace, workspace_size);
if (status != cudaSuccess) {
std::cerr << "Failed to allocate workspace memory: " << cudaGetErrorString(status) << std::endl;
return -1;
}
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
cudaDeviceSynchronize();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
// Free workspace memory once all launches that reference it have completed
cudaFree(workspace);
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.9 or higher Toolkit to run this example
// and the GPU must have compute capability 103.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
std::cerr << "This example requires CUDA 12.9 or newer." << std::endl;
// Return zero so this test passes on older Toolkits; its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 10 && props.minor == 3)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 103)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,38 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 103a)
cutlass_example_add_executable(
89_sm103_fp4_ultra_gemm
89_sm103_fp4_ultra_gemm.cu
)
endif()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,66 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --beta=2.0 --k=512 --groups=51 --iterations=0)
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --beta=0.5 --groups=50 --iterations=0) # Small problem sizes
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_SMALL_GROUP --groups=3 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_SMALL_GROUP --alpha=1.5 --beta=2.0 --groups=3 --iterations=1) # Random problem sizes
if (CUTLASS_NVCC_ARCHS MATCHES 103a)
cutlass_example_add_executable(
90_sm103_fp4_ultra_grouped_gemm
90_sm103_fp4_ultra_grouped_gemm.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_SMALL_GROUP
TEST_EPILOGUE_SMALL_GROUP
)
endif()

View File

@ -0,0 +1,898 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdint> // uint64_t
#include <cstdio>
#include <cstdlib> // rand(), RAND_MAX
#include <string> // std::stoi
#include <vector>
#include <iostream>
#include <float.h>
#include <optional>
#include "cutlass/util/command_line.h"
// clang-format off
#include "cute/tensor.hpp" // FIX cute header file inclusion issue
// clang-format on
#include "cute/arch/mma_sm100_desc.hpp" // cute::UMMA::Major
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
#include "cutlass/complex.h" // cutlass::ComplexTransform
#include "cutlass/cutlass.h" // cutlass::Status
#include "cutlass/detail/sm100_blockscaled_layout.hpp" // cutlass::detail::Sm1xxBlockScaledOutputConfig
#include "cutlass/epilogue/thread/linear_combination.h" // cutlass::epilogue::thread::LinearCombination
#include "cutlass/gemm/device/gemv_blockscaled.h" // cutlass::gemm::device::Gemv
#include "cutlass/gemm/kernel/gemv_blockscaled.h" // cutlass::gemm::kernel::Gemv
#include "cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h" // cutlass::epilogue::threadblock::GemvEpilogueWithScalingFactor
#include "cutlass/gemm_coord.h" // cutlass::GemmCoord
#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory
#include "cutlass/numeric_size.h" // cutlass::is_subbyte
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h" // cutlass::is_same_v
#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation
#include "cutlass/util/distribution.h" // cutlass::Distribution
#include "cutlass/util/host_tensor.h" // cutlass::HostTensor
#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride
#include "cutlass/util/reference/host/gemm_complex.h" // cutlass::reference::host::GemmComplex
#include <cutlass/util/reference/host/gett.hpp> // cutlass::reference::host::GettBlockScalingMainloopParams
// cutlass::reference::host::GettBlockScalingEpilogueParams
// cutlass::reference::host::Gemm3x
#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals
#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
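// Host-side testbed for a block-scaled FP4 GEMV: it initializes A/B/C and their
// scale-factor tensors, runs the device GEMV, runs a host GETT reference with
// scale-factor generation, and bit-compares the resulting D and SFD tensors.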
// Helper Functions
template <typename T>
auto
make_iterator(T* ptr)
{
return cute::recast_ptr<T>(ptr);
}
template <typename Element, typename Layout>
bool
initialize_tensor(cutlass::TensorView<Element, Layout> view, cutlass::Distribution::Kind dist_kind, uint64_t seed)
{
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
} else if (bits_input <= 8) {
if constexpr (cutlass::is_same_v<Element, cutlass::float_ue4m3_t> ||
cutlass::is_same_v<Element, cutlass::float_ue8m0_t>) {
scope_max = 4;
scope_min = 1;
} else {
scope_max = 1;
scope_min = -1;
}
} else {
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else {
CUTLASS_ASSERT(false);
return false;
}
return true;
}
// Base class of Testbed
template <
typename Gemv_,
// The following types are more difficult to derive from the EVT
typename ElementC, typename LayoutC, typename ElementD_,
typename LayoutD, typename ElementSFD_, typename LayoutSFD,
typename ElementCompute_, int kVectorSize_>
struct TestbedGemvFp4SFDBase
{
public:
using Gemv = Gemv_;
using ElementA = typename Gemv::ElementA;
using ElementSFA = typename Gemv::ElementSFA;
using LayoutA = typename Gemv::LayoutA;
static_assert(cutlass::is_same_v<LayoutA, cutlass::layout::RowMajor>, "only support row major matrix A");
static_assert(cutlass::sizeof_bits<ElementSFA>::value == 8, "ElementSFA should be FP8 type");
using ElementB = typename Gemv::ElementB;
using ElementSFB = typename Gemv::ElementSFB;
using LayoutB = cutlass::layout::ColumnMajor;
static_assert(cutlass::is_same_v<ElementA, ElementB>, "only support ElementA ElementB of same type");
static_assert(cutlass::sizeof_bits<ElementSFB>::value == 8, "ElementSFB should be FP8 type");
static_assert(cutlass::is_same_v<LayoutC, cutlass::layout::ColumnMajor>, "only support col major matrix C");
using ElementD = ElementD_;
static_assert(cutlass::is_same_v<LayoutD, cutlass::layout::ColumnMajor>, "only support col major output D");
using ElementSFD = ElementSFD_;
static_assert(cutlass::is_same_v<LayoutSFD, cutlass::layout::ColumnMajor>, "only support col major output SFD");
static_assert(cutlass::sizeof_bits<ElementSFD>::value == 8, "only support 8 bit SFD");
using ElementAccumulator = typename Gemv::ElementAccumulator;
using ElementCompute = ElementCompute_;
static_assert(cutlass::is_same_v<ElementCompute, float>, "only support fp32 epi compute");
static constexpr int kVectorSize = kVectorSize_;
static_assert(kVectorSize == 16, "only support vs 16");
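// kVectorSize is the number of contiguous elements that share a single scale factor (16 for NVFP4).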
// SFD Config
static constexpr bool kIsKMajorSFD = cutlass::is_same_v<LayoutSFD, cutlass::layout::RowMajor>;
using Sm1xxBlockScaledOutputConfig =
cutlass::detail::Sm1xxBlockScaledOutputConfig<kVectorSize,
kIsKMajorSFD ? cute::UMMA::Major::K : cute::UMMA::Major::MN>;
using Blk_MN_Output = typename Sm1xxBlockScaledOutputConfig::Blk_MN;
using Blk_SF_Output = typename Sm1xxBlockScaledOutputConfig::Blk_SF;
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
// SFA SFB Config
using Sm100BlockScaledInputConfig = cutlass::detail::Sm1xxBlockScaledConfig<kVectorSize>;
using Blk_MN_Input = typename Sm100BlockScaledInputConfig::Blk_MN;
using Blk_SF_Input = typename Sm100BlockScaledInputConfig::Blk_SF;
using SfAtom_Input = typename Sm100BlockScaledInputConfig::SfAtom;
public:
TestbedGemvFp4SFDBase(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_SFA_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_SFB_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_SFD_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2023)
: init_A(init_A_)
, init_B(init_B_)
, init_C(init_C_)
, init_D(init_D_)
, init_SFA(init_SFA_)
, init_SFB(init_SFB_)
, init_SFD(init_SFD_)
, seed(seed_)
{
}
bool initialize(cutlass::MatrixCoord problem_size, int32_t batch_count)
{
const int32_t gemm_m = problem_size.row();
const int32_t gemm_k = problem_size.column();
const int32_t gemm_n = 1;
const int32_t gemm_batch = batch_count;
// Resize Config SFA/SFB
auto k_blks_input = cutlass::ceil_div(gemm_k, cute::size<1>(shape(SfAtom_Input{})));
auto m_blks_input = cutlass::ceil_div(gemm_m, Blk_MN_Input{});
auto n_blks_input = cutlass::ceil_div(gemm_n, Blk_MN_Input{});
auto sfa_coord = cutlass::make_Coord(m_blks_input * Blk_MN_Input{} * gemm_batch, k_blks_input * Blk_SF_Input{});
auto sfb_coord = cutlass::make_Coord(n_blks_input * Blk_MN_Input{} * gemm_batch, k_blks_input * Blk_SF_Input{});
auto sfa_resize_layout =
cutlass::layout::Affine2Layout_Factory<LayoutA>::layout_factory(sfa_coord, typename LayoutA::Stride{});
auto sfb_resize_layout =
cutlass::layout::Affine2Layout_Factory<LayoutB>::layout_factory(sfb_coord, typename LayoutB::Stride{});
// Use the same SFD layout generation as reference for tensor creation
using ProblemShapeType = cute::Shape<int, int, int, int>;
auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch};
// Generate the same layout as reference uses
auto sfd_layout = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL);
// Extract size from the generated layout and create coordinate
auto sfd_size = cute::size(cute::filter_zeros(sfd_layout));
auto sfd_coord = cutlass::make_Coord(sfd_size, 1); // Linear layout for HostTensor
auto sfd_resize_layout =
cutlass::layout::Affine2Layout_Factory<LayoutSFD>::layout_factory(sfd_coord, typename LayoutSFD::Stride{});
// Resize Host
this->reference_D.resize({gemm_batch * gemm_m, 1}); // D col major vector
this->reference_SFD.resize(sfd_coord, sfd_resize_layout);
if (initialize_tensor(this->reference_D.host_view(), this->init_D, this->seed + 7) == false) {
printf("initialize_tensor() REF D failed\n");
return false;
}
if (initialize_tensor(this->reference_SFD.host_view(), this->init_SFD, this->seed + 9) == false) {
printf("initialize_tensor() REF SFD failed\n");
return false;
}
// Resize A/B/C/D
this->tensor_A.resize({gemm_batch * gemm_m, gemm_k}); // A row major
this->tensor_B.resize({gemm_batch * gemm_k, 1}); // B col major vector
this->tensor_C.resize({gemm_batch * gemm_m, 1}); // C col major vector
this->tensor_D.resize({gemm_batch * gemm_m, 1}); // D col major vector
this->tensor_SFA.resize(sfa_coord, sfa_resize_layout);
this->tensor_SFB.resize(sfb_coord, sfb_resize_layout);
this->tensor_SFD.resize(sfd_coord, sfd_resize_layout);
// Fill A/B/C
if (initialize_tensor(this->tensor_A.host_view(), this->init_A, this->seed + 1) == false) {
printf("initialize_tensor() A failed\n");
return false;
}
if (initialize_tensor(this->tensor_B.host_view(), this->init_B, this->seed + 2) == false) {
printf("initialize_tensor() B failed\n");
return false;
}
if (initialize_tensor(this->tensor_C.host_view(), this->init_C, this->seed + 3) == false) {
printf("initialize_tensor() C failed\n");
return false;
}
// Fill SFA/SFB
if (initialize_tensor(this->tensor_SFA.host_view(), this->init_SFA, this->seed + 4) == false) {
printf("initialize_tensor() SFA failed\n");
return false;
}
if (initialize_tensor(this->tensor_SFB.host_view(), this->init_SFB, this->seed + 5) == false) {
printf("initialize_tensor() SFB failed\n");
return false;
}
// Fill D/SFD
if (initialize_tensor(this->tensor_D.host_view(), this->init_D, this->seed + 6) == false) {
printf("initialize_tensor() D failed\n");
return false;
}
if (initialize_tensor(this->tensor_SFD.host_view(), this->init_SFD, this->seed + 8) == false) {
printf("initialize_tensor() SFD failed\n");
return false;
}
// Copy A/B/C from host to device
this->tensor_A.sync_device();
this->tensor_B.sync_device();
this->tensor_C.sync_device();
this->tensor_D.sync_device();
this->tensor_SFA.sync_device();
this->tensor_SFB.sync_device();
this->tensor_SFD.sync_device();
// SFD initialization is handled differently:
// tensor_SFD is filled on the host and synced to the device above; its device contents are then
// copied back into reference_SFD's host buffer so both tensors start with identical data.
// Otherwise the "bubbles" in the SFD layout could differ between them and cause a
// false-negative sanity check.
cutlass::device_memory::copy_to_host(this->reference_SFD.host_data(), this->tensor_SFD.device_data(), sfd_size);
return true;
}
bool compare_reference()
{
// device -> host
this->tensor_D.sync_host();
bool passed = true;
// Check
passed = cutlass::reference::host::TensorEquals(this->reference_D.host_view(), this->tensor_D.host_view());
if (passed == false) {
printf("gemm_m: %d, gemm_k: %d, ", this->tensor_A.host_view().extent(0), this->tensor_A.host_view().extent(1));
printf("tensorD mismatch\n");
return false;
}
this->tensor_SFD.sync_host();
passed = cutlass::reference::host::TensorEquals(this->reference_SFD.host_view(), this->tensor_SFD.host_view());
if (passed == false) {
printf("gemm_m: %d, gemm_k: %d, ", this->tensor_A.host_view().extent(0), this->tensor_A.host_view().extent(1));
printf("tensorSFD mismatch\n");
return false;
}
return passed;
}
bool run_reference(cutlass::MatrixCoord problem_size,
int32_t batch_count,
ElementCompute alpha,
ElementCompute beta,
float epilogue_st)
{
const int32_t gemm_m = problem_size.row();
const int32_t gemm_k = problem_size.column();
const int32_t gemm_n = 1;
const int32_t gemm_batch = batch_count;
// Run reference blockscale GETT
using ProblemShapeType = cute::Shape<int, int, int, int>;
auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch};
auto SfD = make_tensor(make_iterator(this->reference_SFD.host_data()),
Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL));
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutD>;
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(gemm_m, gemm_k, gemm_batch));
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(gemm_n, gemm_k, gemm_batch));
StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(gemm_m, gemm_n, gemm_batch));
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(gemm_m, gemm_n, gemm_batch));
auto A = make_tensor(make_iterator(this->tensor_A.host_data()),
cute::make_layout(cute::make_shape(gemm_m, gemm_k, gemm_batch), stride_a));
auto B = make_tensor(make_iterator(this->tensor_B.host_data()),
cute::make_layout(cute::make_shape(gemm_n, gemm_k, gemm_batch), stride_b));
auto C = cute::make_tensor(make_iterator(this->tensor_C.host_data()),
cute::make_layout(cute::make_shape(gemm_m, gemm_n, gemm_batch), stride_c));
auto D = cute::make_tensor(make_iterator(this->reference_D.host_data()),
cute::make_layout(cute::make_shape(gemm_m, gemm_n, gemm_batch), stride_d));
auto layout_sfa = Sm100BlockScaledInputConfig::tile_atom_to_shape_SFA(problem_shape_MNKL);
auto layout_sfb = Sm100BlockScaledInputConfig::tile_atom_to_shape_SFB(problem_shape_MNKL);
auto SfA = make_tensor(this->tensor_SFA.host_data(), layout_sfa);
auto SfB = make_tensor(this->tensor_SFB.host_data(), layout_sfb);
// Internally, the mainloop scale factors are disabled when ElementA/B == ElementSFA/B.
typename cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator, // ElementAccumulator
decltype(A), // TensorA
decltype(SfA), // TensorSfA
decltype(B), // TensorB
decltype(SfB) // TensorSfB
>
mainloop_params{A, SfA, B, SfB};
typename cutlass::reference::host::GettBlockScalingEpilogueParams<ElementCompute, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementCompute, // ElementCompute
decltype(C), // TensorC
decltype(D), // TensorD
decltype(SfD), // TensorSfD
cute::Int<kVectorSize>, // OutputVectorSize
cutlass::reference::host::SfStrategy::SfDGen
>
epilogue_params{alpha, beta, C, D, SfD, epilogue_st};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
return true;
}
virtual typename Gemv::Arguments get_arguments(
cutlass::MatrixCoord problem_size, int32_t batch_count,
float epilogue_st, ElementCompute alpha, ElementCompute beta) = 0;
bool run_gemv(cutlass::MatrixCoord problem_size,
int32_t batch_count,
ElementCompute alpha,
ElementCompute beta,
[[maybe_unused]] float epilogue_st,
bool is_profiling,
int kIterations)
{
// Batched input is not supported for testing
const int32_t gemm_m = problem_size.row();
const int32_t gemm_k = problem_size.column();
[[maybe_unused]] const int32_t gemm_n = 1;
[[maybe_unused]] const int32_t gemm_batch = batch_count;
Gemv gemv_op;
typename Gemv::Arguments arguments = this->get_arguments(
problem_size, batch_count, epilogue_st, alpha, beta
);
cutlass::Status status = gemv_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
printf("can_implement() failed\n");
return false;
}
size_t workspace_size = Gemv::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = gemv_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
printf("initialize() failed\n");
return false;
}
if (not is_profiling) {
status = gemv_op();
}
// profiling
else {
cudaError_t result;
cudaEvent_t events[2];
for (cudaEvent_t &evt : events) {
result = cudaEventCreate(&evt);
if (result != cudaSuccess) {
std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
}
// warmup
status = gemv_op();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Device execution failed on warmup." << std::endl;
return false;
}
result = cudaEventRecord(events[0]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
for (int iter_i = 0; iter_i < kIterations; ++iter_i) {
status = gemv_op();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Device execution failed." << std::endl;
return false;
}
}
result = cudaEventRecord(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
float elapsed_ms = 0;
result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
for (cudaEvent_t &evt : events) {
result = cudaEventDestroy(evt);
if (result != cudaSuccess) {
std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
}
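// Performance model: 2*M*N*K flops per batch; bytes counts A, B, D and their scale-factor
// tensors (one scale factor per kVectorSize elements). C is not included.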
int64_t flops = int64_t(gemm_m) * gemm_n * gemm_k * 2;
int64_t bytes = cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementA>) * int64_t(gemm_m) * int64_t(gemm_k)) +
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementB>) * int64_t(gemm_k) * int64_t(gemm_n)) +
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementD>) * int64_t(gemm_m) * int64_t(gemm_n)) +
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementSFA>) * int64_t(gemm_m) * int64_t(gemm_k) / int64_t(kVectorSize)) +
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementSFB>) * int64_t(gemm_k) * int64_t(gemm_n) / int64_t(kVectorSize)) +
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementSFD>) * int64_t(gemm_m) * int64_t(gemm_n) / int64_t(kVectorSize));
double gflops_per_second = double(flops) * kIterations * gemm_batch / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations * gemm_batch / double(elapsed_ms / 1000.0f) / double(1 << 30);
double elapsed_ms_per_iter = double(elapsed_ms) / kIterations;
std::cout << " Problem: "
<< gemm_m << "-by-" << gemm_n << "-by-" << gemm_k
<< ", batch size: " << gemm_batch
<< std::endl;
std::cout << " Runtime: " << elapsed_ms_per_iter << " ms" << std::endl;
std::cout << " GFLOP/s: " << gflops_per_second << " GFLOP/s" << std::endl;
std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl;
}
if (status != cutlass::Status::kSuccess) {
printf("gemv exec failed\n");
return false;
}
return true;
}
bool run_and_verify(cutlass::MatrixCoord problem_size,
int32_t batch_count,
ElementCompute alpha,
ElementCompute beta,
float epilogue_st)
{
// Initialize Data
if (this->initialize(problem_size, batch_count) == false) {
return false;
}
// Run GEMV kernel
if (this->run_gemv(problem_size, batch_count, alpha, beta, epilogue_st, false /*is_profiling*/, 1) == false) {
return false;
}
// Run Reference Kernel
if (this->run_reference(problem_size, batch_count, alpha, beta, epilogue_st) == false) {
printf("run_reference() failed\n");
return false;
}
// Verify
if (this->compare_reference() == false) {
printf("compare_reference() failed\n");
return false;
}
return true;
}
bool profile(cutlass::MatrixCoord problem_size,
int32_t batch_count,
ElementCompute alpha,
ElementCompute beta,
float epilogue_st,
int kIterations = 10)
{
// Initialize Data
if (this->initialize(problem_size, batch_count) == false) {
return false;
}
// Profile GEMV kernel
if (this->run_gemv(problem_size, batch_count, alpha, beta, epilogue_st, true /*is_profiling*/, kIterations) == false) {
return false;
}
return true;
}
public:
// Data Storage
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
cutlass::HostTensor<ElementSFA, LayoutA> tensor_SFA;
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
cutlass::HostTensor<ElementSFB, LayoutB> tensor_SFB;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<ElementD, LayoutD> tensor_D;
cutlass::HostTensor<ElementSFD, LayoutD> tensor_SFD;
cutlass::HostTensor<ElementD, LayoutD> reference_D;
cutlass::HostTensor<ElementSFD, LayoutD> reference_SFD;
// Data Init Setting
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_D;
cutlass::Distribution::Kind init_SFA;
cutlass::Distribution::Kind init_SFB;
cutlass::Distribution::Kind init_SFD;
uint64_t seed;
};
template<typename Gemv_>
struct TestbedGemvFp4SFD : public TestbedGemvFp4SFDBase<
Gemv_,
typename Gemv_::ElementC,
typename Gemv_::EpilogueOutputOp::LayoutOutput,
typename Gemv_::EpilogueOutputOp::ElementD,
typename Gemv_::EpilogueOutputOp::LayoutOutput,
typename Gemv_::EpilogueOutputOp::ElementSFD,
typename Gemv_::EpilogueOutputOp::LayoutSFD,
typename Gemv_::EpilogueOutputOp::ElementCompute,
Gemv_::EpilogueOutputOp::kVectorSize
> {
using Base = TestbedGemvFp4SFDBase<
Gemv_,
typename Gemv_::ElementC,
typename Gemv_::EpilogueOutputOp::LayoutOutput,
typename Gemv_::EpilogueOutputOp::ElementD,
typename Gemv_::EpilogueOutputOp::LayoutOutput,
typename Gemv_::EpilogueOutputOp::ElementSFD,
typename Gemv_::EpilogueOutputOp::LayoutSFD,
typename Gemv_::EpilogueOutputOp::ElementCompute,
Gemv_::EpilogueOutputOp::kVectorSize
>;
using Base::Base;
using Gemv = Gemv_;
using ElementCompute = typename Base::ElementCompute;
using SfAtom_Input = typename Base::SfAtom_Input;
using Blk_MN_Input = typename Base::Blk_MN_Input;
using Blk_SF_Input = typename Base::Blk_SF_Input;
static constexpr int kVectorSize = Base::kVectorSize;
typename Gemv::Arguments get_arguments(
cutlass::MatrixCoord problem_size,
int32_t batch_count, float epilogue_st,
ElementCompute alpha, ElementCompute beta) override {
const int32_t gemm_m = problem_size.row();
const int32_t gemm_k = problem_size.column();
[[maybe_unused]] const int32_t gemm_n = 1;
[[maybe_unused]] const int32_t gemm_batch = batch_count;
auto k_blks_input = cutlass::ceil_div(gemm_k, cute::size<1>(shape(SfAtom_Input{})));
auto m_blks_input = cutlass::ceil_div(gemm_m, Blk_MN_Input{});
auto n_blks_input = cutlass::ceil_div(gemm_n, Blk_MN_Input{});
int batch_stride_SFA = m_blks_input * Blk_MN_Input{} * k_blks_input * Blk_SF_Input{};
int batch_stride_SFB = n_blks_input * Blk_MN_Input{} * k_blks_input * Blk_SF_Input{};
// Use the same SFD layout generation as reference to get correct batch stride
using ProblemShapeType = cute::Shape<int, int, int, int>;
auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch};
// Generate the same layout as reference uses
using Sm1xxBlockScaledOutputConfig = typename Base::Sm1xxBlockScaledOutputConfig;
auto sfd_layout = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL);
// Calculate batch stride from the generated layout
// Extract the batch stride from the 3rd dimension stride
// The stride<2> gives us the stride for the batch dimension
auto batch_stride_tuple = cute::stride<2>(sfd_layout); // Stride tuple of the batch (L) mode, e.g. (_0, 8192)
int batch_stride_SFD = static_cast<int>(cute::get<1>(batch_stride_tuple)); // The second mode carries the per-batch SFD stride
// Initialize GEMV kernel
typename Gemv::Arguments arguments{
problem_size, // problem_size
batch_count, // batch_count
typename Gemv::EpilogueOutputOp::Params{
this->tensor_D.device_ref(), // tensor_d
this->tensor_SFD.device_data(), // scale_factor_d_ptr
alpha, // alpha
beta, // beta
epilogue_st, // st
batch_stride_SFD, // batch_stride_sfd
gemm_m // stride_d
},
this->tensor_A.device_ref(), // ref_A
this->tensor_B.device_data(), // ptr_B
this->tensor_C.device_data(), // ptr_C
this->tensor_D.device_data(), // ptr_D
this->tensor_SFA.device_data(), // ptr_SFA
this->tensor_SFB.device_data(), // ptr_SFB
gemm_k, // stride_A
gemm_m * gemm_k, // batch_stride_A
gemm_k, // batch_stride_B
gemm_m, // batch_stride_C
gemm_m, // batch_stride_D
batch_stride_SFA, // batch_stride_SFA
batch_stride_SFB, // batch_stride_SFB
batch_stride_SFD // batch_stride_SFD
};
return arguments;
}
};
struct Options {
bool help = false;
int m = 4096;
int k = 2048;
int n = 1;
int batch = 1;
float alpha = 1.0f;
float beta = 0.0f;
float epilogue_st = -1.0f; // sentinel for random
bool profiling = true;
int iterations = 10;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("batch", batch);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("epilogue_st", epilogue_st);
cmd.get_cmd_line_argument("profiling", profiling);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "91_fp4_gemv\n\n"
<< " FP4 GEMV with block-scaled inputs and outputs.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --batch=<int> Sets the batch count of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --epilogue_st=<f32> Epilogue ST value\n\n"
<< " --profiling=<bool> Whether to run profiling\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "91_fp4_gemv" << " --m=4096 --k=2048 --batch=1 \n\n";
return out;
}
};
bool
run_fp4_gemv_device(Options const& options)
{
CUTLASS_ASSERT(options.n == 1);
using ElementA = cutlass::float_e2m1_t;
using ElementSFA = cutlass::float_e4m3_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::float_e2m1_t;
using ElementSFB = cutlass::float_e4m3_t;
using ElementC = cutlass::float_e2m1_t;
using ElementD = cutlass::float_e2m1_t;
using LayoutD = cutlass::layout::ColumnMajor;
using ElementSFD = cutlass::float_e4m3_t;
// Indicates that SF is computed along the column dimension; it does NOT indicate the actual layout of SFD
using LayoutSFD = cutlass::layout::ColumnMajor;
using ElementAccumulatorMainloop = cutlass::half_t;
using ElementAccumulator = float;
using ElementCompute = float;
ElementCompute alpha{options.alpha};
ElementCompute beta{options.beta};
// Must be a positive number.
const float epilogue_st = options.epilogue_st < 0.f ?
static_cast<float>(rand()) / (static_cast<float>(RAND_MAX / 5)) :
options.epilogue_st;
static constexpr int kVectorSize = 16;
static constexpr int kElementsPerAccess = 128 / cutlass::sizeof_bits<ElementA>::value;
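// 128-bit vectorized accesses: 32 FP4 (4-bit) elements per load.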
using ThreadShape = cutlass::gemm::GemmShape<16, 8>;
static_assert(kVectorSize == ThreadShape::kM, "vector size and thread in row should be equal");
// Construct Epilogue
using EpilogueOp = typename cutlass::epilogue::threadblock::GemvEpilogueWithScalingFactor<kVectorSize,
ThreadShape,
ElementCompute,
ElementAccumulator,
ElementC,
ElementD,
ElementSFD,
LayoutD,
LayoutSFD>;
// Construct Mainloop
using Gemv = cutlass::gemm::device::GemvBlockScaled<
cutlass::gemm::kernel::
GemvBlockScaled<ElementA, LayoutA, ElementB, ElementD, ElementAccumulatorMainloop, EpilogueOp, kElementsPerAccess>>;
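// The kernel composes a row-major FP4 A matrix, an FP4 B vector, half-precision mainloop
// accumulation, and the scale-factor-generating epilogue defined above.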
TestbedGemvFp4SFD<Gemv> testbed;
bool pass = true;
if (options.profiling) {
pass = testbed.profile(cutlass::MatrixCoord{options.m, options.k}, options.batch, alpha, beta, epilogue_st, options.iterations);
}
else {
pass = testbed.run_and_verify(cutlass::MatrixCoord{options.m, options.k}, options.batch, alpha, beta, epilogue_st);
}
return pass;
}
int
main(int argc, char const** argv)
{
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Run verification
Options verification_options = options;
verification_options.profiling = false;
bool passed = run_fp4_gemv_device(verification_options);
if (passed == false) {
printf("test fail\n");
return 1;
} else {
printf("test pass\n");
}
if (options.profiling) {
// Start profiling
printf("\nProfiling...\n");
passed = run_fp4_gemv_device(options);
if (passed == false) {
printf("profiling fail\n");
return 1;
} else {
printf("profiling completed\n");
}
}
return 0;
#else
std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM100_SUPPORTED is defined.\n";
return 0;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
}

View File

@ -0,0 +1,36 @@
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (NOT MSVC)
cutlass_example_add_executable(
91_fp4_gemv
91_fp4_gemv.cu
)
endif()

View File

@ -0,0 +1,701 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of a Blackwell MoE-style grouped NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B.
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC loads for the input matrices.
In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
can vary a lot, which requires frequent updates of the TMA descriptors in a TMA-based implementation.
This example uses CPASYNC to load the activation (B) matrix and thus avoids the overhead of updating TMA descriptors.
Usage:
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped
--m=28672 --n=4 --k=4096 --l=8 --benchmark=benchmark.txt
*/
#include <iostream>
#include <fstream>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help;
bool error;
bool verification;
int m, n, k, l;
int iterations;
std::string benchmark_path;
Options():
help(false),
error(false),
verification(true),
m(2048), n(2048), k(2048), l(1),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m, 2048);
cmd.get_cmd_line_argument("n", n, 2048);
cmd.get_cmd_line_argument("k", k, 2048);
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("iterations", iterations, 10);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
if (cmd.check_cmd_line_flag("no_verif")) {
verification = false;
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "92_blackwell_moe_gemm_fp4_grouped\n\n"
<< " Blackwell MoE-style grouped NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --iterations=<int> Set the number of profiling iterations to perform\n"
<< " --benchmark=<file> Executes a benchmark problem size\n"
<< " --no_verif Do not run verification kernels\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <class Element, class Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t> || cute::is_same_v<Element, cutlass::float_ue4m3_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
template <class T>
auto make_iterator(T* ptr) {
return cute::recast_ptr<T>(ptr);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct ExampleRunner {
// Type of kernel schedule to generate
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100;
// Type of epilogue schedule to generate
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
static constexpr bool FuseQuantization = false;
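// When FuseQuantization is true, D is written as NVFP4 with generated block scale factors (SFD);
// when false, D is written in half precision (ElementC) through a plain linear-combination epilogue.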
using LayoutATag = cutlass::layout::RowMajor;
using LayoutBTag = cutlass::layout::ColumnMajor;
using LayoutCTag = cutlass::layout::ColumnMajor;
using LayoutDTag = cutlass::layout::ColumnMajor;
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands
using ElementA = cutlass::nv_float4_t<ElementInput>;
using ElementB = cutlass::nv_float4_t<ElementInput>;
using ElementC = cutlass::half_t;
using ElementD = cute::conditional_t<FuseQuantization, ElementInput, ElementC>;
using ElementSFD = ElementSF;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementScalar = float;
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_64,_256>; // use a tile size of N=64 to match real use cases (N is typically very small in the decoding stage)
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
static constexpr int OutputSFVectorSize = 16;
// D = alpha * acc + beta * C
// With BlockScaleFactor generation.
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
OutputSFVectorSize,
ElementD,
ElementCompute,
ElementSFD, LayoutSFDTag,
ElementC>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
MmaTileMNK, ClusterShapeMNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
EpilogueScheduleType,
cute::conditional_t<
FuseQuantization,
FusionOperation,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
MmaTileMNK, ClusterShapeMNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::CollectiveOp;
using ProblemShapeGroup = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ProblemShapeMax = Shape<int,int,int,int>; // max <M,N,K,L>
using ProblemShape = cutlass::gemm::MoEProblemShape<ProblemShapeGroup, ProblemShapeMax>;
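// MoEProblemShape pairs the per-group <M,N,K> extents above with a max <M,N,K,L> shape;
// initialize() uses the max shape to size the tensors and strides shared by all groups.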
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale-factor tensors have an interleaved layout; use a Layout instead of a stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale-factor tensors have an interleaved layout; use a Layout instead of a stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using FusionOp = typename Gemm::EpilogueOutputOp;
static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutSFD layout_SFD;
uint64_t seed = 0;
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
cutlass::DeviceAllocation<typename ProblemShapeGroup::UnderlyingProblemShape> problem_sizes;
//
// Methods
//
bool verify(ProblemShape const& problem_size, float alpha, float beta) {
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
// think about how to simplify the gemm3x interface.
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
if constexpr (FuseQuantization) {
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementCompute, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementCompute, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D), // TensorD
decltype(tensor_SFD), // TensorSfD
cute::Int<OutputSFVectorSize>,
cutlass::reference::host::SfStrategy::SfDGen
> epilogue_params {alpha, beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
}
else {
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementCompute, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementCompute, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params {alpha, beta, tensor_C, tensor_D };
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
}
bool passed = true;
// Comparison
block_D.sync_host();
auto [maxM, maxN, maxK, L] = problem_size.max_problem_shape;
for (int i = 0; i < problem_size.problem_shape.num_groups; i++) {
auto problem = problem_size.problem_shape.get_host_problem_shape(i);
auto [M, N, K] = problem;
// assume all M == maxM
auto refD_view = block_reference_D.host_view().subview(cutlass::make_Coord(M * N), cutlass::make_Coord(i * maxN * maxM));
auto D_view = block_D.host_view().subview(cutlass::make_Coord(M * N), cutlass::make_Coord(i * maxN * maxM));
passed &= cutlass::reference::host::TensorEquals(refD_view, D_view);
}
return passed;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(ProblemShape const& problem_size) {
auto problem_shape_MNKL = cute::append<4>(problem_size.max_problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// For SFD tensor layout
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// printf("\nStrideC = ");
// print(StrideC{});
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, L});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, L});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, L});
// printf("\nstride_C = ");
// print(stride_C);
layout_A = make_layout(make_shape(M, K, L), stride_A);
layout_B = make_layout(make_shape(N, K, L), stride_B);
layout_C = make_layout(make_shape(M, N, L), stride_C);
layout_D = make_layout(make_shape(M, N, L), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, L));
// printf("\nlayout_A = ");
// print(layout_A);
// printf("\nlayout_B = ");
// print(layout_B);
// printf("\nlayout_C = ");
// print(layout_C);
// printf("\nsize(layout_A)=%lld", (long long)size(layout_A));
// printf("\n");
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_Normconst.reset(cutlass::make_Coord(1));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_Normconst.at(cutlass::make_Coord(0)) = 2;
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_D.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
block_SFD.sync_device();
block_Normconst.sync_device();
}
/// Load a benchmark
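/// Each non-empty line of the benchmark file is expected to have the form "<index> <M>x<N>x<K>"
/// (for example "0 1024x8x4096"); parsing stops at the first line that does not match.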
std::vector<ProblemShapeGroup::UnderlyingProblemShape> benchmark_problems(std::string const& benchmark_path) {
std::vector<ProblemShapeGroup::UnderlyingProblemShape> problem_sizes_host;
std::ifstream file(benchmark_path);
if (!file.good()) {
return {};
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
cutlass::gemm::GemmCoord extent;
std::vector<std::string> tokens;
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
extent.at(i) = std::atoi(tokens.at(i).c_str());
}
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
return problem_sizes_host;
}
bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
auto problem_sizes_host = benchmark_problems(options.benchmark_path);
if (problem_sizes_host.empty()) {
return false;
}
problem_sizes.reset(problem_sizes_host.size());
problem_sizes.copy_from_host(problem_sizes_host.data());
ProblemShape problem_size;
problem_size.max_problem_shape = ProblemShapeMax{options.m, options.n, options.k, options.l};
problem_size.problem_shape.num_groups = problem_sizes_host.size();
problem_size.problem_shape.problem_shapes = problem_sizes.get();
problem_size.problem_shape.host_problem_shapes = problem_sizes_host.data();
initialize(problem_size);
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
problem_size,
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
},
hw_info
};
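// Only wire up the SFD and norm-constant pointers when the epilogue fusion actually supports
// block scaling; the constexpr dispatch below keeps these assignments out of the build otherwise.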
auto f = [&](auto blockscale) {
auto impl = [this](auto& arguments) {
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
};
if constexpr (decltype(blockscale)::value) {
impl(arguments);
}
};
f(std::bool_constant<IsBlockScaleSupported>());
// arguments.scheduler.max_swizzle_size = options.swizzle;
arguments.epilogue.thread.alpha = 1.0f;
arguments.epilogue.thread.beta = 0.0f;
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
// Run the GEMM
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
if (options.verification) {
// Verify that the result is correct
bool passed = verify(problem_size, 1.0f, 0.0f);
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
if (!passed) {
exit(-1);
return false;
}
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_op.run());
}
timer.stop();
// Compute average setup and runtime and FLOPs.
float elapsed_ms = timer.elapsed_millis();
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);
std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
std::cout << " TFLOPS : " << flops / 1e12 << std::endl;
}
return true;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older toolkits. Its actions are a no-op.
return 0;
}
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
ExampleRunner runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
#endif
return 0;
}

View File

@ -0,0 +1,654 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of a Blackwell MoE-style NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B.
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC loads for the input matrices.
In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
can vary a lot, which requires frequent updates of the TMA descriptors in a TMA-based implementation.
This example uses CPASYNC to load the activation (B) matrix and thus avoids the overhead of updating TMA descriptors.
It assumes all experts have the same number of tokens, in which case the GEMM becomes a regular (batched) GEMM.
For the realistic use case where each expert may have a different number of tokens (grouped GEMM), see 92_blackwell_moe_gemm_fp4_grouped.
Usage:
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular
--m=28672 --n=4 --k=4096 --l=8
*/
#include <iostream>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help;
bool error;
bool verification;
int m, n, k, l;
int iterations;
Options():
help(false),
error(false),
verification(true),
m(2048), n(2048), k(2048), l(1),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m, 2048);
cmd.get_cmd_line_argument("n", n, 2048);
cmd.get_cmd_line_argument("k", k, 2048);
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("iterations", iterations, 10);
if (cmd.check_cmd_line_flag("no_verif")) {
verification = false;
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "92_blackwell_moe_gemm_fp4_regular\n\n"
<< " Blackwell NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --iterations=<int> Set the number of profiling iterations to perform\n"
<< " --no_verif Do not run verification kernels\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <class Element, class Layout>
bool initialize_block(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6) {
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8) {
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t> || cute::is_same_v<Element, cutlass::float_ue4m3_t>) {
scope_max = 4;
scope_min = 1;
}
else {
scope_max = 1;
scope_min = -1;
}
}
else{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
template <class T>
auto make_iterator(T* ptr) {
return cute::recast_ptr<T>(ptr);
}
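// make_iterator wraps the raw host pointer with cute::recast_ptr so that sub-byte element
// types (such as the 4-bit float_e2m1_t inputs here) are accessed through CuTe's sub-byte
// iterators when building the host-side tensors used by the reference check.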
///////////////////////////////////////////////////////////////////////////////////////////////////
// MSVC complains about this definition if it is moved into ExampleRunner
static constexpr int OutputSFVectorSize = 16;
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
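// OutputSFVectorSize = 16 means one output scale factor is generated per 16 elements of D,
// matching the 16-element NVFP4 block size used for the A/B operands.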
// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the
// number of pipeline stages.
template <
// Type of kernel schedule to generate
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
// Type of epilogue schedule to generate
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
bool FuseQuantization = false
>
struct ExampleRunner {
using LayoutATag = cutlass::layout::RowMajor;
using LayoutBTag = cutlass::layout::ColumnMajor;
using LayoutCTag = cutlass::layout::ColumnMajor;
using LayoutDTag = cutlass::layout::ColumnMajor;
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands
using ElementA = cutlass::nv_float4_t<ElementInput>;
using ElementB = cutlass::nv_float4_t<ElementInput>;
using ElementC = cutlass::half_t;
using ElementD = cute::conditional_t<FuseQuantization, ElementInput, ElementC>;
using ElementSFD = ElementSF;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementScalar = float;
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
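// With 4-bit inputs, an alignment of 32 elements for A/B corresponds to 16 bytes, matching the
// 16B alignment noted in the FP8 variants of this example; C and D use 128-bit alignment
// expressed in elements of their respective types.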
// D = alpha * acc + beta * C
// With BlockScaleFactor generation.
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
OutputSFVectorSize,
ElementD,
ElementCompute,
ElementSFD, LayoutSFDTag,
ElementC>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
MmaTileMNK, ClusterShapeMNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
EpilogueScheduleType,
cute::conditional_t<
FuseQuantization,
FusionOperation,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
MmaTileMNK, ClusterShapeMNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale factor tensors have an interleaved layout; use the full Layout rather than just a Stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale factor tensors have an interleaved layout; use the full Layout rather than just a Stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
using FusionOp = typename Gemm::EpilogueOutputOp;
static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
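// When FuseQuantization is enabled, the LinCombBlockScaleFactor fusion reports block-scale
// support here, which gates the SFD and norm-constant epilogue arguments set in run().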
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
//
// Data members
//
/// Initialization
StrideA stride_A;
LayoutA layout_A;
LayoutSFA layout_SFA;
StrideB stride_B;
LayoutB layout_B;
LayoutSFB layout_SFB;
StrideC stride_C;
LayoutC layout_C;
StrideD stride_D;
LayoutD layout_D;
LayoutSFD layout_SFD;
uint64_t seed = 0;
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
//
// Methods
//
bool verify(ProblemShapeType const& problem_size, float alpha, float beta) {
// Create the arguments for host reference implementation
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
// TODO: consider simplifying the Gemm3x reference interface.
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator, // ElementAccumulator
decltype(tensor_A), // TensorA
decltype(tensor_SFA), // TensorSfA
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
if constexpr (FuseQuantization) {
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementCompute, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementCompute, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D), // TensorD
decltype(tensor_SFD), // TensorSfD
cute::Int<OutputSFVectorSize>,
cutlass::reference::host::SfStrategy::SfDGen
> epilogue_params {alpha, beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
}
else {
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementCompute, // ElementScalar
ElementAccumulator, // ElementAccumulator
ElementCompute, // ElementCompute
decltype(tensor_C), // TensorC
decltype(tensor_D) // TensorD
> epilogue_params {alpha, beta, tensor_C, tensor_D };
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
}
bool passed = true, passed_sfd = true;
// Comparison
block_D.sync_host();
passed &= cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
if constexpr (FuseQuantization) {
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
block_SFD.sync_host();
passed_sfd &= cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view());
passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0);
passed_sfd &= (cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0);
}
// printf("passed=%d\n", (int)passed);
// printf("passed_sfd=%d\n", (int)passed_sfd);
return passed && passed_sfd;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(ProblemShapeType const& problem_size) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;
// For SFA and SFB tensors layouts
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// For SFD tensor layout
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
// printf("\nStrideC = ");
// print(StrideC{});
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, L});
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L});
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, L});
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, L});
// printf("\nstride_C = ");
// print(stride_C);
layout_A = make_layout(make_shape(M, K, L), stride_A);
layout_B = make_layout(make_shape(N, K, L), stride_B);
layout_C = make_layout(make_shape(M, N, L), stride_C);
layout_D = make_layout(make_shape(M, N, L), stride_D);
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, L));
// printf("\nlayout_A = ");
// print(layout_A);
// printf("\nlayout_B = ");
// print(layout_B);
// printf("\nlayout_C = ");
// print(layout_C);
// printf("\nsize(layout_A)=%lld", (long long)size(layout_A));
// printf("\n");
block_A.reset(cutlass::make_Coord(size(layout_A)));
block_B.reset(cutlass::make_Coord(size(layout_B)));
block_C.reset(cutlass::make_Coord(size(layout_C)));
block_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
block_Normconst.reset(cutlass::make_Coord(1));
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
initialize_block(block_A.host_view(), seed + 2021);
initialize_block(block_B.host_view(), seed + 2022);
initialize_block(block_C.host_view(), seed + 2023);
initialize_block(block_SFA.host_view(), seed + 2024);
initialize_block(block_SFB.host_view(), seed + 2025);
block_Normconst.at(cutlass::make_Coord(0)) = 2;
block_A.sync_device();
block_B.sync_device();
block_C.sync_device();
block_D.sync_device();
block_SFA.sync_device();
block_SFB.sync_device();
block_SFD.sync_device();
block_Normconst.sync_device();
}
bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
initialize(problem_size);
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{ // Mainloop arguments
block_A.device_data(), stride_A,
block_B.device_data(), stride_B,
block_SFA.device_data(), layout_SFA,
block_SFB.device_data(), layout_SFB
},
{ // Epilogue arguments
{},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
},
hw_info
};
if constexpr (IsBlockScaleSupported) {
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
}
// arguments.scheduler.max_swizzle_size = options.swizzle;
arguments.epilogue.thread.alpha = 1.0f;
arguments.epilogue.thread.beta = 0.0f;
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
// Run the GEMM
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
if (options.verification) {
// Verify that the result is correct
bool passed = verify(problem_size, 1.0f, 0.0f);
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
if (!passed) {
exit(-1);
return false;
}
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_op.run());
}
timer.stop();
// Compute average setup and runtime and FLOPs.
float elapsed_ms = timer.elapsed_millis();
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);
std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
std::cout << " TFLOPS : " << flops / 1e12 << std::endl;
}
return true;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Return zero so this test passes on older toolkits; in that case the example is a no-op.
return 0;
}
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
std::cout << "Running kernel with TMA load:" << std::endl;
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100> runner_tma;
runner_tma.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100> runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
#endif
return 0;
}

View File

@ -0,0 +1,541 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of Blackwell MoE-style grouped GEMM implementation using TMA to load A and CPASYNC to load B.
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices.
In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
can vary widely, which requires frequent updates of TMA descriptors in a purely TMA-based implementation.
This example uses CPASYNC to load the activation (B) matrix, avoiding the overhead of updating TMA descriptors.
Usage:
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped
--m=28672 --n=4 --k=4096 --l=8 --benchmark=benchmark.txt
*/
#include <iostream>
#include <fstream>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help;
bool error;
bool verification;
int m, n, k, l;
int iterations;
std::string benchmark_path;
Options():
help(false),
error(false),
verification(true),
m(2048), n(2048), k(2048), l(1),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m, 2048);
cmd.get_cmd_line_argument("n", n, 2048);
cmd.get_cmd_line_argument("k", k, 2048);
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("iterations", iterations, 10);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
if (cmd.check_cmd_line_flag("no_verif")) {
verification = false;
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "92_blackwell_moe_gemm_grouped\n\n"
<< " Blackwell MoE-style grouped GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --iterations=<int> Set the number of profiling iterations to perform\n"
<< " --benchmark=<file> Executes a benchmark problem size\n"
<< " --no_verif Do not run verification kernels\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(0);
}
else if (bits_input <= 8) {
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(-2);
}
else {
scope_max = static_cast<Element>(8);
scope_min = static_cast<Element>(-8);
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct ExampleRunner {
// Type of kernel schedule to generate
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100;
// Type of epilogue schedule to generate
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementScalar = float;
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage)
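// For the 1-byte FP8 element type used here, Int<128 / sizeof(ElementA)> evaluates to a K tile of 128.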
// 16B alignment lets us use TMA
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileMNK, ClusterShapeMNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
EpilogueScheduleType,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
MmaTileMNK, ClusterShapeMNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::CollectiveOp;
using ProblemShapeGroup = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ProblemShapeMax = Shape<int,int,int,int>; // max <M,N,K,L>
using ProblemShape = cutlass::gemm::MoEProblemShape<ProblemShapeGroup, ProblemShapeMax>;
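// MoEProblemShape pairs the per-group <M,N,K> shapes (supplied at runtime, e.g. from the
// benchmark file parsed below) with a max <M,N,K,L> shape that initialize() uses to size the
// device allocations shared by all groups.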
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
//, cutlass::gemm::MoEScheduler
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed = 0;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;
cutlass::DeviceAllocation<typename ProblemShapeGroup::UnderlyingProblemShape> problem_sizes;
//
// Methods
//
bool verify(ProblemShape const& problem_size, float alpha, float beta) {
auto [maxM, maxN, maxK, L] = problem_size.max_problem_shape;
for (int i = 0; i < problem_size.problem_shape.num_groups; i++) {
auto problem = problem_size.problem_shape.get_host_problem_shape(i);
auto [M, N, K] = problem;
cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, Gemm::LayoutA(maxK));
cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, Gemm::LayoutB(maxK));
cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutC(maxM));
cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutD(maxM));
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementScalar,
ElementAccumulator>;
DeviceGemmReference gemm_reference;
gemm_reference(
{M, N, K},
ElementScalar(alpha),
ref_A,
ref_B,
ElementScalar(beta),
ref_C,
ref_D);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
// Check whether the CUTLASS and reference outputs match.
// Note: this comparison assumes M == maxM for every group, since D is packed with leading dimension maxM.
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + size_t(1) * i * maxN * maxM, block_D.get() + size_t(1) * i * maxN * maxM, M * N);
if (!passed) {
return false;
}
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(ProblemShape const& problem_size) {
auto problem_shape_MNKL = cute::append<4>(problem_size.max_problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
block_A.reset(size_t(1) * M * K * L);
block_B.reset(size_t(1) * K * N * L);
block_C.reset(size_t(1) * M * N * L);
block_D.reset(size_t(1) * M * N * L);
block_ref_D.reset(size_t(1) * M * N * L);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
/// Load a benchmark
std::vector<ProblemShapeGroup::UnderlyingProblemShape> benchmark_problems(std::string const& benchmark_path) {
std::vector<ProblemShapeGroup::UnderlyingProblemShape> problem_sizes_host;
std::ifstream file(benchmark_path);
if (!file.good()) {
return {};
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
cutlass::gemm::GemmCoord extent;
std::vector<std::string> tokens;
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
extent.at(i) = std::atoi(tokens.at(i).c_str());
}
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
return problem_sizes_host;
}
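// A benchmark file, as parsed above, lists one group per line as "<index> <MxNxK>", for
// example (illustrative values only):
//   0 4096x4x7168
//   1 4096x2x7168
//   2 4096x8x7168
// Parsing stops at end of file or at the first malformed line.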
bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
auto problem_sizes_host = benchmark_problems(options.benchmark_path);
if (problem_sizes_host.empty()) {
return false;
}
problem_sizes.reset(problem_sizes_host.size());
problem_sizes.copy_from_host(problem_sizes_host.data());
ProblemShape problem_size;
problem_size.max_problem_shape = ProblemShapeMax{options.m, options.n, options.k, options.l};
problem_size.problem_shape.num_groups = problem_sizes_host.size();
problem_size.problem_shape.problem_shapes = problem_sizes.get();
problem_size.problem_shape.host_problem_shapes = problem_sizes_host.data();
initialize(problem_size);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
problem_size,
{block_A.get(), stride_A, block_B.get(), stride_B},
{{}, // epilogue.thread
block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
// arguments.scheduler.max_swizzle_size = options.swizzle;
arguments.epilogue.thread.alpha = 1.0f;
arguments.epilogue.thread.beta = 0.0f;
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
// Run the GEMM
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
if (options.verification) {
// Verify that the result is correct
bool passed = verify(problem_size, 1.0f, 0.0f);
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
if (!passed) {
exit(-1);
return false;
}
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_op.run());
}
timer.stop();
// Compute average setup and runtime and FLOPs.
float elapsed_ms = timer.elapsed_millis();
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);
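// Note: FLOPs are computed from the max problem shape (options.m/n/k/l), not from the
// per-group sizes read from the benchmark file, so the reported TFLOPS is an estimate.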
std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
std::cout << " TFLOPS : " << flops / 1e12 << std::endl;
}
return true;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Return zero so this test passes on older toolkits; in that case the example is a no-op.
return 0;
}
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
ExampleRunner runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
#endif
return 0;
}

View File

@ -0,0 +1,484 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of Blackwell MoE-style GEMM implementation using TMA to load A and CPASYNC to load B.
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices.
In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
can vary widely, which requires frequent updates of TMA descriptors in a purely TMA-based implementation.
This example uses CPASYNC to load the activation (B) matrix, avoiding the overhead of updating TMA descriptors.
Usage:
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular
--m=28672 --n=4 --k=4096 --l=8
*/
#include <iostream>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command line options parsing
struct Options {
bool help;
bool error;
bool verification;
int m, n, k, l;
int iterations;
Options():
help(false),
error(false),
verification(true),
m(2048), n(2048), k(2048), l(1),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m, 2048);
cmd.get_cmd_line_argument("n", n, 2048);
cmd.get_cmd_line_argument("k", k, 2048);
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("iterations", iterations, 10);
if (cmd.check_cmd_line_flag("no_verif")) {
verification = false;
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "92_blackwell_moe_gemm_regular\n\n"
<< " Blackwell GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
<< " --iterations=<int> Set the number of profiling iterations to perform\n"
<< " --no_verif Do not run verification kernels\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(0);
}
else if (bits_input <= 8) {
scope_max = static_cast<Element>(2);
scope_min = static_cast<Element>(-2);
}
else {
scope_max = static_cast<Element>(8);
scope_min = static_cast<Element>(-8);
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the
// number of pipeline stages.
template <
// Type of kernel schedule to generate
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
// Type of epilogue schedule to generate
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto
>
struct ExampleRunner {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementScalar = float;
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage)
// 16B alignment lets us use TMA
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileMNK, ClusterShapeMNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
EpilogueScheduleType,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
MmaTileMNK, ClusterShapeMNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed = 0;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;
//
// Methods
//
bool verify(ProblemShapeType const& problem_size, float alpha, float beta) {
auto [M, N, K, L] = problem_size;
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));
cutlass::reference::device::GemmComplex(
{M, N, K},
ElementScalar(alpha),
ref_A,
cutlass::ComplexTransform::kNone,
ref_B,
cutlass::ComplexTransform::kNone,
ElementScalar(beta),
ref_C,
ref_D,
ElementAccumulator(0),
L, // batch_count
M * K, // batch_stride_A
K * N, // batch_stride_B
M * N, // batch_stride_C
M * N // batch_stride_D
);
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Reference kernel failed. Last CUDA error: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
return passed;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(ProblemShapeType const& problem_size) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
block_A.reset(size_t(1) * M * K * L);
block_B.reset(size_t(1) * K * N * L);
block_C.reset(size_t(1) * M * N * L);
block_D.reset(size_t(1) * M * N * L);
block_ref_D.reset(size_t(1) * M * N * L);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
initialize(problem_size);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A.get(), stride_A, block_B.get(), stride_B},
{{}, // epilogue.thread
block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
// arguments.scheduler.max_swizzle_size = options.swizzle;
arguments.epilogue.thread.alpha = 1.0f;
arguments.epilogue.thread.beta = 0.0f;
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "This kernel is not supported. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
// Run the GEMM
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
return false;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
<< cudaGetErrorString(result) << std::endl;
return false;
}
if (options.verification) {
// Verify that the result is correct
bool passed = verify(problem_size, 1.0f, 0.0f);
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
if (!passed) {
exit(-1);
return false;
}
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_op.run());
}
timer.stop();
// Compute average setup and runtime and FLOPs.
float elapsed_ms = timer.elapsed_millis();
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);
std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
std::cout << " TFLOPS : " << flops / 1e12 << std::endl;
}
return true;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Return zero so this test passes on older toolkits; in that case the example is a no-op.
return 0;
}
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
std::cout << "Running kernel with TMA load:" << std::endl;
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100> runner_tma;
runner_tma.run(options, hw_info);
std::cout << "Running kernel with CPASYNC load:" << std::endl;
ExampleRunner<cutlass::gemm::KernelWarpSpecialized1SmSm100> runner_cpasync;
runner_cpasync.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100> runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
#endif
return 0;
}

View File

@ -0,0 +1,70 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
set(TEST_MIXTRAL_A --m=28672 --n=4 --k=4096 --l=8)
set(TEST_MIXTRAL_B --m=4096 --n=4 --k=14336 --l=8)
set(TEST_DEEPSEEK_A --m=4096 --n=1 --k=7168 --l=256)
set(TEST_DEEPSEEK_B --m=7168 --n=1 --k=2048 --l=256)
set(TEST_IRREGULAR_MNK --m=4080 --n=9 --k=4112 --l=8) # M,N,K not multiples of tile size
set(TEST_DEEPSEEK_A_FP4 --m=1024 --n=1 --k=7168 --l=256) # TP=1 shape is too large for PackedVectorLayout
set(TEST_DEEPSEEK_B_FP4 --m=7168 --n=1 --k=512 --l=256)
set(TEST_IRREGULAR_MNK_FP4 --m=4080 --n=9 --k=4160 --l=8)
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
92_blackwell_moe_gemm_regular
92_blackwell_moe_gemm_regular.cu
TEST_COMMAND_OPTIONS
TEST_MIXTRAL_A
TEST_MIXTRAL_B
TEST_DEEPSEEK_A
TEST_DEEPSEEK_B
TEST_IRREGULAR_MNK
)
cutlass_example_add_executable(
92_blackwell_moe_gemm_grouped
92_blackwell_moe_gemm_grouped.cu
)
cutlass_example_add_executable(
92_blackwell_moe_gemm_fp4_regular
92_blackwell_moe_gemm_fp4_regular.cu
TEST_COMMAND_OPTIONS
TEST_MIXTRAL_A
TEST_MIXTRAL_B
TEST_DEEPSEEK_A_FP4
TEST_DEEPSEEK_B_FP4
TEST_IRREGULAR_MNK_FP4
)
cutlass_example_add_executable(
92_blackwell_moe_gemm_fp4_grouped
92_blackwell_moe_gemm_fp4_grouped.cu
)
endif()

View File

@ -163,7 +163,13 @@ foreach(EXAMPLE
82_blackwell_distributed_gemm
83_blackwell_sparse_gemm
84_blackwell_narrow_precision_sparse_gemm
86_blackwell_mixed_dtype_gemm
87_blackwell_geforce_gemm_blockwise
88_hopper_fmha
89_sm103_fp4_ultra_gemm
90_sm103_fp4_ultra_grouped_gemm
91_fp4_gemv
92_blackwell_moe_gemm
)
add_subdirectory(${EXAMPLE})

Some files were not shown because too many files have changed in this diff.