Compare commits

...

65 Commits

Author SHA1 Message Date
8afb19d904 update CITATION.cff 2025-10-28 23:42:37 -04:00
b2ca083d2b Fixed compilation error when using StreamK scheduler + PDL. (#2686) 2025-10-21 23:11:14 -04:00
b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00
e6e2cc29f5 fix (#2684) 2025-10-15 14:46:38 -04:00
c6aeb9179c Update pyproject.toml
update version to 4.2.1
2025-09-24 01:18:51 -04:00
95a5ff14c0 Update CHANGELOG.md
format change
2025-09-23 17:33:00 -04:00
fb8b43ef05 Merge pull request #2669 from NVIDIA/421_update
4.2.1 update
2025-09-23 14:02:29 -07:00
f874df19ac 4.2.1 update 2025-09-23 13:45:13 -07:00
7a6d4ee099 v4.2.1 update. (#2666) 2025-09-23 13:25:43 -04:00
GTO
2b8dff1f90 Fix bfloat16 epsilon (#2607)
* Fix bfloat16 epsilon

* just use constants

---------

Co-authored-by: Konstantin <konstantin@MacBook-Air.local>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2025-09-21 23:43:59 -04:00
fd0312ddf6 Remove duplicate function calls (#1584) 2025-09-21 23:16:59 -04:00
64579189ec Feature/add bottom causal mask (#2480)
* Rebase to latest

* update

* upd

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Update fmha_fusion.hpp

* Update fmha_fusion.hpp

fixed flipped logic for isQBegin

* Update fmha_fusion.hpp

* Avoid use of booleans

The current expression is confusing

* fmt

* Update fmha_fusion.hpp

Reproduce error/fix with: 
./77_blackwell_fmha_fp16 --verify --b=1 --q=1013 --k=1024 --h=1 --h_k=1 --mask=causal --causal-type=qend

* add test, format

---------

Co-authored-by: Richard Cai <ricai@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-09-18 17:11:23 -04:00
b234a8c024 Rename python/cutlass to python/cutlass_cppgen (#2652) 2025-09-18 14:26:57 -04:00
74825181f2 Remove old-version dsl examples. (#2644) 2025-09-17 22:23:30 -04:00
8825e8be4f Add required changes for github pipeline. (#2648) 2025-09-17 22:22:45 -04:00
wbn
7817e47154 Fxied a typo in pipeline descript docs. (#2623) 2025-09-15 22:32:27 -04:00
25ccb875b8 Fix: a calculation error in the example of dividing out in the 02_layout_algebra doc (#2635) 2025-09-15 22:31:33 -04:00
29c1ad704a Fix doc cute 03_tensor.md link typo (#2627)
* Update 03_tensor.md fix link typo

change path to relative path

* Update 03_tensor.md

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-09-15 22:26:43 -04:00
57e3cfb47a doc change for 4.2 (#2639)
* doc change

* fix broken links

* ragged gemm doc update

* move around texts about moe gemm
2025-09-15 22:02:45 -04:00
e7e0adddac Update version.h
change version number to 4.2
2025-09-15 12:40:58 -04:00
6a35b4d22f v4.2 tag release. (#2638) 2025-09-15 12:21:53 -04:00
56f0718a97 ex77 backwards GQA (#2556)
* bwd GQA init

* Update examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu

* ref kernel type conversion fix

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2025-09-09 12:53:28 -04:00
76c96b0be3 Fix incorrect shapes in copy_atom doc comments. (#2575) 2025-09-04 16:57:24 -07:00
d98e7bf7ce Fix comment in mma_atom.hpp (#2579) 2025-09-04 16:56:39 -07:00
b6ccf34aef Fix Copy_Atom type mismatch in sgemm_sm80.cu (#2582) 2025-09-04 16:56:17 -07:00
2288c0c901 Fix bugs in matrix.h (#2598) 2025-09-04 16:55:11 -07:00
b2dd65dc86 more robust imports in heuristics.py and heuristics_provider.py (#2596) 2025-08-28 22:32:55 -04:00
496654bf2c Fix sm100 gemm wrong static constexpr that breaks compilation on Windows (#2167)
* Fix a sm100 gemm wrong defined static constexpr that breaks compilation on Windows

* Fix a sm100 gemm wrong defined static constexpr that breaks compilation on Windows

* More Windows fixes

Signed-off-by: Javier <25750030+SystemPanic@users.noreply.github.com>

* Revert "More Windows fixes"

This reverts commit 2e8cfc1382.

---------

Signed-off-by: Javier <25750030+SystemPanic@users.noreply.github.com>
2025-08-28 22:13:00 -04:00
9ca7e877b2 fix gqa issue for blackwell fmha.py (#2599) 2025-08-28 11:15:20 -04:00
a49a78ffef v4.2 release. (#2587)
* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in command line.

* v4.2 release.
2025-08-22 18:11:24 -04:00
11cad1f67b fix a typo. (#2561) 2025-08-19 22:23:09 -04:00
931359cec1 Fix typo in functional.h (#2571) 2025-08-19 22:22:31 -04:00
42e7c546c4 Add movmatrix support (movmatrix.sync.aligned.m8n8.trans.b16) (#2562) 2025-08-19 22:22:02 -04:00
ec18e8043b Make swizzle in pycute work (#2553) 2025-08-19 22:21:00 -04:00
5b76420d6a [DOC] Add more exposition to composition example (#2536)
* Add more exposition to composition example

* Apply suggestions from code review

Co-authored-by: Cris Cecka <ccecka@users.noreply.github.com>

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Cris Cecka <ccecka@users.noreply.github.com>
2025-08-11 22:20:36 -04:00
19772cd63e Fix typo in smem_allocator.py (#2517) 2025-08-10 22:44:22 -04:00
052afcd314 fix typo (#2529) 2025-08-10 22:44:02 -04:00
86cf63e2d4 NIT: Grammar (#2537) 2025-08-10 22:42:45 -04:00
a267d47f9b Update batched_gemm.cu (#2538) 2025-08-10 22:42:21 -04:00
9e6ab77d27 Fix a copy error in the SM70 main loop when loading data from smem to rmem (#2540) 2025-08-10 22:42:01 -04:00
d0eada85a3 Support both CUDA 12 and 13 cccl header locations (#2543) 2025-08-10 22:41:25 -04:00
23139309e9 Fix incorrect K dim in CuTe MMA Atom doc. (#2544) 2025-08-10 22:40:56 -04:00
6dd13d4278 Facebook:This commit makes its files safe for use with -Wimplicit-fallthrough. (#2324) 2025-07-31 20:55:19 -04:00
3b054767b3 Fix typo (#2514) 2025-07-30 22:14:54 -04:00
6fb5e667c1 [Doc fix] incorrect compute cap. for Blackwell RTX (#2511)
Blackwell RTX is compute capability 12.0 (SM120) but incorrectly listed
as SM100 in the README.
2025-07-30 22:14:13 -04:00
6c891db9f6 Fix epilogue:🧵:Convert cannot be used with cute::collective::DefaultEpilogue. (#2333) 2025-07-30 22:12:53 -04:00
da47886e34 Fix example bug (#2351) 2025-07-30 22:12:33 -04:00
26b7450023 support fp16 accmulator for sm89 fp8 mma (#2378)
* add support for sm89 in cute and the unit tests

* support fp16 accmulator for sm89 fp8 mma

* format code
2025-07-30 22:12:08 -04:00
a39cf6b511 Fix example in CuTe tutorials (#2416) 2025-07-30 22:11:47 -04:00
f09045d660 Corrected minor nit in mma_traits.hpp (#2447)
* Corrected minor nit in mma_traits.hpp

The entry and descriptions were jumbled up.

* Update mma_traits.hpp

* Update mma_traits.hpp
2025-07-30 22:11:23 -04:00
84a27b3926 fix: examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu GridDim miscalculated (#2492)
* fix: examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu Launch dimGrid error

* feat: add cta tiler

* Update examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu

use cluster_layout_vmnk instead of cta_tiler

Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>

* feat: remove cta_tiler

---------

Co-authored-by: qinghongzeng <qinghongzeng@deeproute.ai>
Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>
2025-07-30 22:11:04 -04:00
e093b4f691 Fix tutorial comment in sgemm_1.cu: use tCrC instead of tCsA in axpby explanation (#2448) 2025-07-30 22:09:55 -04:00
664c4f7b3e Update CUTLASS version to 4.1
Update CUTLASS version to 4.1.
2025-07-26 20:11:04 -04:00
0e026982ce Example 77 add blackwell fmha bwd for MLA shape (#2466)
* Update examples/77_blackwell_fmha/device/fmha_device_bwd.hpp

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>

* bug fix & use existing value rather than pass one more argument to support different dim in bwd_convert

* Fix casual mask cnt when IsQBegin==false

* bug fix in casual mask backward

* code sync

---------

Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
2025-07-24 18:41:11 -04:00
9a9a579714 Merge pull request #2489 from NVIDIA/update_workflow_script
Support "CuTe DSL" auto-labeling in workflow
2025-07-23 15:33:43 +08:00
51d730b8be Support "CuTe DSL" auto-labeling in workflow 2025-07-23 00:28:01 -07:00
6c0c8b7484 1. Update bug/feature report template to add component selection. (#2485)
2. Add workflow to apply component label automatically
2025-07-22 12:38:03 -04:00
e51efbfe18 Update CHANGELOG.md 2025-07-21 22:09:56 -04:00
fd6cfe1ed0 v4.1 release update v2. (#2481) 2025-07-21 22:03:55 -04:00
9baa06dd57 Add Blackwell MLA forward (shape: d=192, dv=128) implementation in example_77 (#2472) 2025-07-18 01:27:48 -04:00
ebe98c549a cache procedural_name in GemmOperation (#2317) 2025-07-16 22:25:02 -04:00
9892624b66 Fix typos in the text (#2417) 2025-07-16 21:51:12 -04:00
a1aaf2300a v4.1 release 2025-07-03 08:07:53 -04:00
b995f93317 4.0 doc change (#2425) 2025-06-27 09:35:06 -04:00
889ff20648 v4.0 update v2. (#2420)
* Ex77 forward kernel fix.
2025-06-25 12:56:25 -04:00
768 changed files with 130533 additions and 21942 deletions

View File

@ -1,23 +0,0 @@
---
name: Bug report
about: Create a bug report to help us improve CUTLASS
title: "[BUG]"
labels: "? - Needs Triage, bug"
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Steps/Code to reproduce bug**
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Environment details (please complete the following information):**
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]
**Additional context**
Add any other context about the problem here.

38
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: Bug Report
description: Create a bug report to help us improve CUTLASS
title: "[BUG] "
labels: ["? - Needs Triage", "bug"]
assignees: []
body:
- type: dropdown
id: component
attributes:
label: Which component has the problem?
options:
- CuTe DSL
- CUTLASS C++
validations:
required: true
- type: textarea
id: bug-report
attributes:
label: Bug Report
description: Please fill out all sections below
value: |
**Describe the bug**
A clear and concise description of what the bug is.
**Steps/Code to reproduce bug**
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Environment details (please complete the following information):**
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]
**Additional context**
Add any other context about the problem here.
validations:
required: true

View File

@ -1,20 +0,0 @@
---
name: Feature request
about: Suggest an idea for CUTLASS
title: "[FEA]"
labels: "? - Needs Triage, feature request"
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context, code examples, or references to existing implementations about the feature request here.

View File

@ -0,0 +1,35 @@
name: Feature Request
description: Suggest an idea for CUTLASS
title: "[FEA] "
labels: ["? - Needs Triage", "feature request"]
assignees: []
body:
- type: dropdown
id: component
attributes:
label: Which component requires the feature?
options:
- CuTe DSL
- CUTLASS C++
validations:
required: true
- type: textarea
id: feature-request
attributes:
label: Feature Request
description: Please fill out all sections below
value: |
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context, code examples, or references to existing implementations about the feature request here.
validations:
required: true

51
.github/workflows/auto-label-issues.yml vendored Normal file
View File

@ -0,0 +1,51 @@
name: Auto Label Issues
on:
issues:
types: [opened]
jobs:
add-labels:
runs-on: ubuntu-latest
permissions:
issues: write
steps:
- name: Add component label
uses: actions/github-script@v7
with:
script: |
const issue = context.payload.issue;
const body = issue.body || '';
// Parse the issue body to find the component selection
// GitHub renders dropdown selections as "### {label}\n\n{selection}"
// Check for both bug report and feature request dropdown labels
const bugComponentMatch = body.match(/### Which component has the problem\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);
const featureComponentMatch = body.match(/### Which component requires the feature\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);
const componentMatch = bugComponentMatch || featureComponentMatch;
if (componentMatch) {
const component = componentMatch[1].trim();
let label = '';
// Map component selections to labels
switch(component) {
case 'CuTe DSL':
label = 'CuTe DSL';
break;
case 'CUTLASS C++':
label = 'CUTLASS C++';
break;
}
if (label) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: issue.number,
labels: [label]
});
console.log(`Added label: ${label}`);
}
}

View File

@ -55,7 +55,7 @@ jobs:
if: |
(startsWith(github.event.comment.body, '/bot run') ||
startsWith(github.event.comment.body, '/bot kill')) && contains(
fromJson('["zekunf-nv"]'),
fromJson('["nv-fastkernels-cicd", "zekunf-nv", "hwu36", "IonThruster", "thakkarV", "d-k-b", "mihir-awatramani", "fengxie", "vickiw973", "Junkai-Wu", "brandon-yujie-sun", "lijingticy22", "hongw-nv", "vikgupta-nv", "IwakuraRein", "depaulmillz", "jackkosaian", "itramble", "ccecka", "sxtyzhangzk", "hbarclay", "yzhaiustc", "x86vk", "sklevtsov-nvidia", "ANIKET-SHIVAM", "Shreya-gaur", "azhurkevich", "serifyesil", "richardmcai", "lsyyy666", "Ethan-Yan27", "XiaoSong9905", "shdetect", "keithzzzzz"]'),
github.actor)
steps:
- name: Check if comment is issued by authorized person

View File

@ -2,14 +2,222 @@
# CUTLASS 4.x
## [4.3.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-10-20)
### CuTe DSL
* Debuggability improvements:
- Supported source location tracking for DSL APIs
- Supported dumping PTX and CUBIN code
* More examples and notebooks to get started with CuTe DSL:
- [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
- Improved performance of elementwise kernel (https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/elementwise_apply.py):
+ Generalize code to handle list of input tensors
+ Generalize TV layout computation to handle different data types
- Demonstrate the new Pipeline APIs in [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py):
+ New Pipeline API `PipelineProducer` and `PipelineConsumer` to simplify code (no more explicit pipeline state management)
- Separate epilogue code for non-TMA and TMA implementation
+ Note that the updates simplifies the codes but existing APIs still work and are supported
- [Basic Blackwell SM100 GEMM with decent performance](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py)
+ Simple tutorial achieves 84% SOL performance with MNK 8K
- Reworked [elementwise add notebook](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb) with more details and detailed explanation about TV layout
+ Updated implementation to handle general data type and multiple inputs
+ Updated explanation for TV layout in simpler language
+ Added visualization of TV Layout with 3rd party utils
- [Benchmark and autotune demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb)
* More examples of authorizing peak-performance kernels:
- [Blackwell SM100 mixed-input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py)
- [Blackwell SM100 persistent blockwise dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py)
- [Blackwell SM100 persistent blockwise contiguous grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py)
- [Blackwell SM100 persistent blockwise masked grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py)
- [Blackwell SM100 fmha bwd](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha_bwd.py)
- [Blackwell SM100 mla](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mla.py)
- [Hopper SM90 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py)
- [Blackwell GeForce batched dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py)
- [Ampere HSTU Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/hstu_attention.py)
* API updates:
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixings and improvements
- Add mma_tiler_n=64 and mma_tiler_n=192 support in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py).
- Fixed ``TensorSSA.reduce`` to support static value as initial value
- Updated docstring for following APIs to be more concise and easier to understand:
- ``make_layout_tv``
- ``is_static``
- ``PipelineAsync``
- ``SmemAllocator``
- Fixed documentation for ``pipeline``, ``utils`` and ``cute.math``
### CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
- Add softmax skip correction.
- Fix a shared memory allocation bug where it needs to opt in maximum dynamics shared memory explicitly once it exceeds 48KB.
- Fix a dead hang issue caused by early return warp.
* Add Ragged Contiguous Grouped gemm kernel in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/).
- This kernel uses a TMA 3D load to load the weights matrix and use the tensormap update method to load activations.
* Optimize group gemm kernels by enabling async TMA desc update.
* Support Blackwell SM100 convolution stream-K kernel.
- Unit tests: [fprop_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [dgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [wgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu).
* Add profiler support for Blackwell SM100 and SM120 blockscaled sparse kernels.
* Fix some kernel issues:
- Fix a race check issue of Blackwell SM103 kernels by adding missing elect one for prefetch barrier initialization.
- Allow user to directly specify the number of stages for Hopper sm90 mixed input gemm.
- Remove warnings caused by cuda vector type alignment setting in CUDA 13.
- Remove problematic `cutlass::int8_t` and replace it with `int8_t`.
* Fix some profiler issues:
- Add some missing reference kernels.
- Add calculation of scale factor A and B in function `bytes_with_problem_shape` of block scaled profiler.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 13.0U1.
## [4.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.1) (2025-09-22)
### CuTe DSL
* Bug fixings and improvements
- Fixed an issue when running DSL codes with cuda-python 13.0
- Fixed an issue when running inductor with DSL codes
- Fixed an issue with unexpected logging when running DSL codes in FlashInfer
- Fixed the issue reported in https://github.com/NVIDIA/cutlass/issues/2647
- Fixed an issue when conditional define of variables outside of dynamic control flow
### CUTLASS C++
* Bypass EVT for nosmem blockwise kernels on Blackwell.
* Rename cutlass/python/cutlass directory to cutlass/python/cutlass_cppgen.
## [4.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0) (2025-09-15)
### CuTe DSL
* More Python versions are now supported for both x86-64 and aarch64, including
- Python 3.10, 3.11, 3.12, and 3.13
* Added new example and updated notebook to get started with CuTe DSL
- [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py)
- Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb)
+ Added a section for introducing the broadcast
* API updates
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixings and improvements
- Fixed ``cute.print_tensor`` for coordinate tensor
- Fixed `cute.print` for tuple of layouts
- Fixed frozen object is not properly updated after fully assigned in dynamic control flow
- Fixed assign tuple/list element in a dynamic control flow may cause compilation failure
- Improved error message when CUDA context is not initialized
- Improved docstring of congruent and weakly_congruent
### CUTLASS C++
* Support for Blackwell SM103 kernels for B300 GPUs.
- Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp)
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
- Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture:
- [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/).
- [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm).
* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM
- Unit test files with prefix name of `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/).
* Support for Blackwell SM121 kernels for DGX Spark GPUs.
- Share the major codes with Blackwell SM120 kernels.
* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario.
- Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md).
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
- Add fused reduction kernel support for cutlass MLA.
- Add softmax skip correction.
- Support for GQA in FMHA backward kernel.
- Fix an issue where `get_unmasked_trip_count` may return a negative value.
- Fix an issue where mbarriers are initialized with a zero arrival count.
- Fix a corner case issue where the sequence length of q is not a multiple of tile_q.
- Remove tma padding for forward kernel inputs.
* Add Blackwell SM100 kernels for MoEs (focusing on Low-Latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allow only one problem dimension to vary across groups/experts, unlike general Grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on API is welcome.
* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell
- On Blackwell SM120, a blockwise gemm kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/).
- On Hopper, add K major scale factor support for SM90 blockwise kernels.
- On Hopper, relax the restriction that the k dimension of the problem size has to be the multiple of the k dimension of the tile size.
- On Hopper, grouped version supports the case when k = 0.
* Support for Blackwell SM100 fp4 gemv kernels.
- Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h).
- Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/)
* Support for Blackwell SM100 legacy mixed input GEMM kernels.
- Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp).
- Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp).
- Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/).
* Support for Blackwell SM100 cpasync kernel.
- Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp).
- Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp).
* Support Blackwell SM120 mixed input blockscaled grouped GEMM.
* Instantiating more Blackwell kernels in profiler.
- Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations.
- To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. Profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels.
- Details please check [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md).
* Fix some profiler issues:
- Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line.
- Fix some no output and timeout issues.
- Fix Pingpong Blockwise Hopper library generation.
* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110.
- For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs.
- For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid.
* Rename legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to legacy Python interface.
- Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`.
- Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter.
- Added some support for running SM100 kernels via the Python interface.
* CuTe changes:
- Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/).
- Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support.
- Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels.
- Support fp16 accmulator for sm89 fp8 mma.
- Shorten `nullspace` implementation.
- Isolate and comment on `cosize` risky changes.
- Important documentation correction: `E<0,1> == 1@0@1`.
* Fix some kernel issues:
- Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers.
- Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel.
* Add following unit tests:
- [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu)
- [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu)
- [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu)
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 13.0U1.
## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)
### CuTe DSL
* Add aarch64 support, you can now pip install `nvidia-cutlass-dsl` on GB200 systems!
* More examples demonstrating how to use CuTe DSL to write peak-performance kernels
- [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
- [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
* API updates
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
### CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
- Add variable sequence length support for FMHA Backward kernel.
- Add varlen test support to Backward runner.
- Codes support empty batch sequences.
* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays.
* CuTe changes:
- Rewrite ArithTuple and ScaledBasis for robustness and clarity.
- Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX.
- Factor out `print_latex` and friends and rewrite.
- Factor out `print_svg` and friends and rewrite.
* Support Blackwell SM100 SIMT packed fp32x2 kernels.
* Support residual add for implicit gemm kernels.
* Various fixes for CUTLASS C++ Python interface's EVT tracer:
- Add verifier for sm90 to report the invalid input.
- When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges.
- Register operations of tanh, sigmoid, exp, gelu to the python ast frontend.
- Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback.
* Fix profiler bugs in exhaustive perf search.
- Fix incorrect cluster shape output issue when doing exhaustive search.
- Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders.
* Fix some profiler issues.
- Complete the reference for Blackwell blockwise gemm kernels.
- Fix incorrect regex logic for L1 test.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.9.
## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)
### CuTe DSL
* CuTe DSL, a Python DSL centered around CuTe's abstractions
- [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass)
- [DSL quick start](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html)
- [DSL Overview](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass/latest)
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
- [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
- [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
@ -21,7 +229,7 @@
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
### CUTLASS C++
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
@ -35,7 +243,13 @@
- Added non-power-of-two tile sizes.
- Improved performance for K-major scale factors.
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
* Support LSE output in Blackwell SM100 FMHA Forward kernel in example 77.
* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
- Support LSE output in FMHA Forward kernel.
- Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent.
- Enhance testing of variable sequence length.
- Disable B2B mode in MLA to simplify the sample.
- Clarify that `fmha_gen` sample only supports head dim 128.
- Fixes for split-kv output in MLA.
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
@ -103,7 +317,7 @@
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
* Support `void` as the D element in sm100 kernel epilogues.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.8U1.
@ -122,7 +336,7 @@
- [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp).
- [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp).
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
- [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
@ -160,11 +374,11 @@
- A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
- A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html)
- A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
- Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture).
- Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.
- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html)
- A new [functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
- Updates to [compatibility](https://docs.nvidia.com/cutlass/latest/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/latest/overview.html#target-architecture).
- Updates to [profiler documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.
## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
- [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
@ -177,7 +391,7 @@
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
+ A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler).
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#cutlass-profiler).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
- Optimal code generation with CUDA toolkit versions 12.6.
@ -191,12 +405,12 @@
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
- [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html).
- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html).
- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
- A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
- A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h)
- Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
@ -219,7 +433,7 @@
- Support for residual add (beta != 0) in convolution kernels.
- A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
- A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt).
- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html).
- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html).
- Better support for MSVC as a host compiler.
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
@ -227,7 +441,7 @@
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp)
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html).
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html).
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp).
+ Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
+ [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.

View File

@ -57,6 +57,10 @@ authors:
family-names: Blasig
email: dblasig@nvidia.com
affiliation: NVIDIA
- given-names: Aditya
family-names: Atluri
email: aatluri@nvidia.com
affiliation: NVIDIA
- given-names: Fengqi
family-names: Qiao
email: fqiao@nvidia.com

View File

@ -73,6 +73,16 @@ endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
# nvcc supports response files with --options-file but some tools like clangd
# might choke on it. Thus provide a way to control the use of this feature.
set(CUTLASS_CUDA_USE_RESPONSE_FILE ON CACHE BOOL "Enable CUDA response files for includes, libraries, and objects")
if(NOT CUTLASS_CUDA_USE_RESPONSE_FILE)
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0)
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0)
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0)
endif()
if (CUDA_VERSION VERSION_LESS 11.3)
message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 11.4 or higher, and strongly recommends CUDA 11.8 or higher.")
elseif (CUDA_VERSION VERSION_LESS 11.4)
@ -175,13 +185,25 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a 121 121a)
if (CUDA_VERSION VERSION_LESS 13.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
else()
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
endif()
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f 121f 103a 103f)
if (CUDA_VERSION VERSION_LESS 13.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
else()
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110f)
endif()
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
@ -288,17 +310,49 @@ if (KERNEL_FILTER_FILE)
set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE)
endif()
if (CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
get_filename_component(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" ABSOLUTE)
set(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" CACHE STRING "HEURISTICS FILE FULL PATH" FORCE)
endif()
set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list")
if(KERNEL_FILTER_FILE)
message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}")
endif()
if(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
message(STATUS "Full path of heuristics problems file: ${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}")
if(DEFINED CUTLASS_NVMMH_URL)
message(STATUS "CUTLASS_NVVMH_URL is set. Fetching dependency")
include(FetchContent)
FetchContent_Declare(
nvmmh
URL ${CUTLASS_NVMMH_URL}
)
FetchContent_MakeAvailable(nvmmh)
FetchContent_GetProperties(nvmmh SOURCE_DIR nvmmh_dir)
set(CUTLASS_NVMMH_PATH "${nvmmh_dir}")
endif()
if(DEFINED CUTLASS_NVMMH_PATH)
message(STATUS "CUTLASS_NVMMH_PATH is set. Using package at: ${CUTLASS_NVMMH_PATH}")
set(CUTLASS_NVMMH_PY_DIR "${CUTLASS_NVMMH_PATH}/python/")
set(ENV{CUTLASS_NVMMH_SO_PATH} "${CUTLASS_NVMMH_PATH}/lib/libnvMatmulHeuristics.so")
endif()
endif()
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.")
set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 and SM100 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")
if(CUTLASS_LIBRARY_INSTANTIATION_LEVEL OR CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
message(STATUS "Enable extended SM90 WGMMA instruction shapes for instantiation levels")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
endif()
################################################################################
@ -350,6 +404,10 @@ if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()
if (CUTLASS_NVCC_ARCHS MATCHES 110f)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()
set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")
#
@ -428,8 +486,6 @@ endif()
#
###################################################################################################
# Warnings-as-error exceptions and warning suppressions for Clang builds
if (CUTLASS_CLANG_HOST_COMPILE)
@ -704,9 +760,16 @@ target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)
if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)
endif()
install(
DIRECTORY
${CUTLASS_INCLUDE_DIR}/
@ -751,9 +814,9 @@ if(NOT WIN32)
# Add common library search paths so executables and libraries can load and run
# without LD_LIBRARY_PATH being set.
link_libraries(
"-Wl,-rpath,'$ORIGIN'"
"-Wl,-rpath,'$ORIGIN/../lib64'"
"-Wl,-rpath,'$ORIGIN/../lib'"
"-Wl,-rpath,'$$ORIGIN'"
"-Wl,-rpath,'$$ORIGIN/../lib64'"
"-Wl,-rpath,'$$ORIGIN/../lib'"
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'"
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'"
${CMAKE_DL_LIBS}
@ -881,7 +944,7 @@ function(cutlass_add_executable_tests NAME TARGET)
install(
FILES ${__RESULT_CACHE_FILE}
DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}/
DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}
)
endif()
@ -1009,7 +1072,7 @@ function(cutlass_generate_profiler_tests NAME)
install(
FILES ${CUTLASS_PROFILER_REGRESSION_LIST_FILE}
DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass/
DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass
RENAME profiler_regressions.csv
)

174
README.md
View File

@ -1,9 +1,9 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# Overview
# CUTLASS 4.0.0
# CUTLASS 4.3.0
_CUTLASS 4.0.0 - May 2025_
_CUTLASS 4.3.0 - Oct 2025_
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@ -40,64 +40,77 @@ designs, and bringing optimized solutions into production.
CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
To get started quickly - please refer :
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html).
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html).
# What's New in CUTLASS 4.0
# What's New in CUTLASS 4.3
### CuTe DSL
* CuTe DSL, a Python DSL centered around CuTe's abstractions
- [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass)
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
- [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
- [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
- [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
## CuTe DSL
* Debuggability improvements:
- Supported source location tracking for DSL APIs
- Supported dumping PTX and CUBIN code
* More examples and notebooks to get started with CuTe DSL:
- [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
- Improved performance of elementwise kernel (https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/elementwise_apply.py):
+ Generalize code to handle list of input tensors
+ Generalize TV layout computation to handle different data types
- Demonstrate the new Pipeline APIs in [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py):
+ New Pipeline API `PipelineProducer` and `PipelineConsumer` to simplify code (no more explicit pipeline state management)
- Separate epilogue code for non-TMA and TMA implementation
+ Note that the updates simplifies the codes but existing APIs still work and are supported
- [Basic Blackwell SM100 GEMM with decent performance](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py)
+ Simple tutorial achieves 84% SOL performance with MNK 8K
- Reworked [elementwise add notebook](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb) with more details and detailed explanation about TV layout
+ Updated implementation to handle general data type and multiple inputs
+ Updated explanation for TV layout in simpler language
+ Added visualization of TV Layout with 3rd party utils
- [Benchmark and autotune demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb)
* More examples of authorizing peak-performance kernels:
- [Blackwell SM100 mixed-input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py)
- [Blackwell SM100 persistent blockwise dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py)
- [Blackwell SM100 persistent blockwise contiguous grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py)
- [Blackwell SM100 persistent blockwise masked grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py)
- [Blackwell SM100 fmha bwd](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha_bwd.py)
- [Blackwell SM100 mla](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mla.py)
- [Hopper SM90 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py)
- [Blackwell GeForce batched dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py)
- [Ampere HSTU Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/hstu_attention.py)
* API updates:
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixings and improvements
- Add mma_tiler_n=64 and mma_tiler_n=192 support in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py).
- Fixed ``TensorSSA.reduce`` to support static value as initial value
- Updated docstring for following APIs to be more concise and easier to understand:
- ``make_layout_tv``
- ``is_static``
- ``PipelineAsync``
- ``SmemAllocator``
- Fixed documentation for ``pipeline``, ``utils`` and ``cute.math``
### CUTLASS C++
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
- 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note 101a is supported since CUTLASS 3.9
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
- For example:
+ `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
+ `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
- If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, filter kernels in the CUTLASS profiler with `--kernels`), please update your uses accordingly, this is a breaking change.
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
- Added non-power-of-two tile sizes.
- Improved performance for K-major scale factors.
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
* Support LSE output in Blackwell SM100 FMHA Forward kernel in example 77.
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels.
* Fix profiler issues which cause no output or not supported error for some kernels.
* Optimizations for Blackwell SM100 and SM120 block scaled kernels.
* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler.
* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* CuTe changes:
- Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda and introduces transform-tensors to avoid any extra register or load/store overhead in using bool-tensors.
- New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy.
- Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
- Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.9.
## CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
- Add softmax skip correction.
- Fix a shared memory allocation bug where it needs to opt in maximum dynamics shared memory explicitly once it exceeds 48KB.
- Fix a dead hang issue caused by early return warp.
* Add Ragged Contiguous Grouped gemm kernel in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/).
- This kernel uses a TMA 3D load to load the weights matrix and use the tensormap update method to load activations.
* Optimize group gemm kernels by enabling async TMA desc update.
* Support Blackwell SM100 convolution stream-K kernel.
- Unit tests: [fprop_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [dgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [wgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu).
* Add profiler support for Blackwell SM100 and SM120 blockscaled sparse kernels.
* Fix some kernel issues:
- Fix a race check issue of Blackwell SM103 kernels by adding missing elect one for prefetch barrier initialization.
- Allow user to directly specify the number of stages for Hopper sm90 mixed input gemm.
- Remove warnings caused by cuda vector type alignment setting in CUDA 13.
- Remove problematic `cutlass::int8_t` and replace it with `int8_t`.
* Fix some profiler issues:
- Add some missing reference kernels.
- Add calculation of scale factor A and B in function `bytes_with_problem_shape` of block scaled profiler.
Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits.
CUTLASS team is working on a fix.
**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.**
**See the [CHANGELOG](https://docs.nvidia.com/cutlass/latest/CHANGELOG.html) for details of all past releases and updates.**
# Performance
@ -139,7 +152,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
This greatly simplifies the design and improves code composability and readability.
More documentation specific to CuTe can be found in its
[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html).
[dedicated documentation directory](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html).
# Compatibility
@ -186,7 +199,10 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|NVIDIA GeForce RTX 50x0 series |10.0|12.8|
|NVIDIA B300 Tensor Core GPU |10.3|13.0|
|NVIDIA DRIVE Thor |11.0|13.0|
|NVIDIA GeForce RTX 50x0 series |12.0|12.8|
|NVIDIA DGX Spark |12.1|13.0|
## Target Architecture
@ -218,11 +234,11 @@ cmake .. -DCUTLASS_NVCC_ARCHS="100a"
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
products has a different compute capability than the one underpinning
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
NVIDIA Blackwell GeForce RTX 50 series GPUs (SM120). As a result, kernels
compiled for Blackwell SM100 architecture with arch conditional features
(using `sm100a`) are not compatible with RTX 50 series GPUs.
Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html)
Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html)
for details on which kernels require which target architectures.
# Documentation
@ -230,22 +246,22 @@ for details on which kernels require which target architectures.
CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code
- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application
- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
- [Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
- [Functionality](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
- [Terminology](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html) - describes terms used in the code
- [Programming Guidelines](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
- [Tile Iterators](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html) - command-line driven profiling application
- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
- [Dependent kernel launch](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
kernels in the same stream, and how it is used in CUTLASS.
# Resources
@ -265,7 +281,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
paths.
CUTLASS unit tests, examples, and utilities can be build with CMake.
The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.
@ -310,7 +326,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
and template concepts defined in the CUTLASS project.
A detailed explanation of the source code organization may be found in the
[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below.
[CUTLASS documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/code_organization.html), but several main components are summarized below.
## CUTLASS Template Library
@ -384,7 +400,7 @@ tools/
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html).
# Performance Profiling
@ -600,9 +616,9 @@ reference_device: Passed
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html)
- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html)
# About

View File

@ -36,6 +36,7 @@ set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
function(cutlass_generate_kernel_filter_and_testlist_files)
set(options)
@ -65,7 +66,12 @@ endfunction()
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f)
set(PROFILER_ARCH_LIST 100a 100f 103a 120a 120f 121a)
if (CUDA_VERSION VERSION_LESS 13.0)
list(APPEND PROFILER_ARCH_LIST 101a 101f)
else()
list(APPEND PROFILER_ARCH_LIST 110a 110f)
endif()
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests")

View File

@ -45,7 +45,7 @@
cutlass::half_t
This is a numeric type implementing IEEE half-precision quantities. It is functional in host
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables harware-accelerated
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables hardware-accelerated
numeric conversion on x86-64 CPUs support F16C extensions. In device code, all available
hardware is used to implement conversion and numeric operations.

View File

@ -243,10 +243,11 @@ cudaError_t run_batched_gemm(bool use_array) {
const char* gemm_desc = use_array ? "array" : "strided batched";
std::cout << "Running " << gemm_desc << " gemm" << std::endl;
// Arbitrary problem size
// Arbitrary matrix shape
int const m = 520;
int const n = 219;
int const k = 129;
int const batch_count = 17;
// A, B are non-transpose, column major

View File

@ -64,7 +64,7 @@ ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutla
ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not
enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do
that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB
to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C
to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C
which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the
data type of output ElementOutput (int32_t), the number of elements per vector memory access (16),
data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X +

View File

@ -64,7 +64,7 @@ ElementComputeEpilogue (int32_t), ElementInputA (int8_t), ElementInputB (int8_t)
(int32_t). Communicating just the data type is not enough. As the data is laid out linearly in
memory, we have to convey the layout of matrices. We do that by initializing template variable
LayoutInputA to column major cutlass variable, LayoutInputB to row major and LayoutOutput to row
major. Next, we setup rules to comptue alpha * X + beta * C which is called epilogue of the kernel.
major. Next, we setup rules to compute alpha * X + beta * C which is called epilogue of the kernel.
We initialize template variable EpilogueOp, which takes the data type of output ElementOutput
(int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t)
and data type of computation of linear combination (alpha * X + beta * C).

View File

@ -66,7 +66,7 @@ ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB
ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out
linearly in memory, we have to convey the layout of tensors. We do that by initializing template
variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup
rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template
rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template
variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of
elements per vector memory access (32), data type of accumulator (int32_t) and data type of
computation of linear combination (alpha * X + beta * C).

View File

@ -177,7 +177,7 @@ public:
if(args.split_k_mode == SplitKMode::kParallel) {
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
// The user needs to call a reduction operator to optain the final output tensor
// The user needs to call a reduction operator to obtain the final output tensor
workspace_bytes =
sizeof(ElementAccumulator) *
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *

View File

@ -153,7 +153,7 @@ struct Options {
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
<< " run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n"
<< " back-to-back GEMMs are subject to.s"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"

View File

@ -248,7 +248,7 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problmes in the group. Thus,
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
@ -402,7 +402,7 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problmes in the group. Thus,
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
@ -434,7 +434,7 @@ struct B2bGemm {
// Only row-major outputs are currently supported, so no transpose is performed
}
/// Returns non-grouped paramaters to be used as input to the kernel-level
/// Returns non-grouped parameters to be used as input to the kernel-level
/// operator for the problem indicated by problem_visitor.
CUTLASS_HOST_DEVICE
Params to_single_params(const ProblemVisitor& problem_visitor) const {

View File

@ -560,7 +560,7 @@ struct DefaultB2bConv2dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
// multistage pipeline with interleaved layout.
template <
typename ElementA,

View File

@ -606,7 +606,7 @@ struct DefaultB2bConv2dFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
// multistage pipeline with interleaved layout.
/// Accumulator will be staged in shared memory.
template <

View File

@ -277,7 +277,7 @@ public:
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumualtor tile
FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment

View File

@ -298,7 +298,7 @@ public:
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumualtor tile
FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment

View File

@ -93,7 +93,7 @@ template <
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation perfomed by GEMM
/// Operation performed by GEMM
typename Operator,
/// Epilogue output operator
typename EpilogueOutputOp,

View File

@ -203,7 +203,7 @@ requires any memory for scratch space.
If yes, we reserve scratch space and pass it along
with other arguments to initialize the CUTLASS kernel.
After lauching the CUTLASS kernel, this example runs
After launching the CUTLASS kernel, this example runs
a reference convolution kernel (from CUTLASS utilities)
to check correctness.
*/

View File

@ -144,7 +144,7 @@ int run() {
// Construct Gemm ProblemSize with user defined output size
cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024};
// Stride factor shows the distance between two elements in the differnet dimensions. The
// Stride factor shows the distance between two elements in the different dimensions. The
// first data is the logical distance between two rows, the second is between two columns.
// CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory<Layout>::layout_factory
// to help to convert stride_factor to the two strides.

View File

@ -55,7 +55,7 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
// Define the overal warp-level problem shape
// Define the overall warp-level problem shape
int const kM = 27;
int const kN = 31;
int const kK = 17;

View File

@ -59,7 +59,7 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
// Define the overal warp-level problem shape
// Define the overall warp-level problem shape
int const kM = 14;
int const kN = 27;
int const kK = 17;

View File

@ -659,7 +659,7 @@ struct Testbed {
}
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
int64_t bytes = cutlass::bits_to_bytes(
int64_t bytes = cutlass::bits_to_bytes<int64_t>(
(cutlass::sizeof_bits<ElementD>::value * 2 + cutlass::sizeof_bits<ElementSoftmax>::value) *
options.problem_size.m() * options.problem_size.n());

View File

@ -30,7 +30,7 @@
**************************************************************************************************/
// This example fuses gather before GEMM and scatter after GEMM into the same
// GEMM kernel. Gather and scatter operation is controled by an index vector
// GEMM kernel. Gather and scatter operation is controlled by an index vector
// to select rows or columns from A, B, C or D matrices.
//
// Suppose, all matrices are column major. The pseudo code of the fused kernel

View File

@ -87,7 +87,7 @@ public:
using ElementLayernormCompute = ElementLayernormCompute_;
using ThreadblockShape = ThreadblockShape_;
// Pre-processing has ensured the layout equivelent to RowMajor
// Pre-processing has ensured the layout equivalent to RowMajor
using Layout = cutlass::layout::RowMajor;
using TensorVariance = TensorRef<ElementVariance, Layout>;

View File

@ -33,8 +33,8 @@
computing reference permutations of 4/5D tensors when source data is column-major.
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/coord.h"

View File

@ -87,7 +87,7 @@ parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorC32RSK32"],

View File

@ -86,7 +86,7 @@ parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignement of input tensor A")
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],

View File

@ -40,14 +40,12 @@
Note that in general the fragment passed to the OutputOp could
span multiple rows but it does not happen with the configurations we have
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"

View File

@ -42,12 +42,10 @@
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"

View File

@ -55,7 +55,7 @@
```
In practice, and for numerical stability reasons,
we also substract the maximum so far (`mi`) before doing
we also subtract the maximum so far (`mi`) before doing
the exponential. When we encounter new keys, the maximum
used to compute O so far (`m_prime`) can differ from the
current maximum, so we update O before accumulating with

View File

@ -55,7 +55,7 @@
```
In practice, and for numerical stability reasons,
we also substract the maximum so far (`mi`) before doing
we also subtract the maximum so far (`mi`) before doing
the exponential. When we encounter new keys, the maximum
used to compute O so far (`m_prime`) can differ from the
current maximum, so we update O before accumulating with

View File

@ -31,7 +31,7 @@
/*! \file
\brief Cutlass provides helper template functions to figure out the right
datastructures to instanciate to run a GEMM with various parameters (see
datastructures to instantiate to run a GEMM with various parameters (see
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
instantiation priority rules, it will only create an MmaMultiStage with
kStages=3 (otherwise creates an MmePipelined - which is not compatible with
@ -83,7 +83,7 @@ template <
typename InstructionShape,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation perfomed by GEMM
/// Operation performed by GEMM
typename Operator,
typename Enable_ = void>
struct FindDefaultMma {

View File

@ -522,7 +522,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
// For API compatibility with MmaMultistageFromSharedMemory
// but not supported as it worsens perf: older gpus < sm80 don't
// support async tranfers and have to waste registers
// support async transfers and have to waste registers
CUTLASS_DEVICE
void set_prologue_done(bool value) {}
CUTLASS_DEVICE

View File

@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*! \file
\brief Instanciates the right WarpIterator to read from shared memory
\brief Instantiates the right WarpIterator to read from shared memory
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
data dumped with `B2bGemm::accumToSmem`.
*/

View File

@ -86,7 +86,7 @@ namespace threadblock {
/// To be efficient, this assumes the iterator will be dereferenced and advanced
/// at least once outside any looping structure to minimize integer arithmetic.
///
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
/// dereferencing the iterator.
///
///

View File

@ -49,7 +49,7 @@
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
for this example:
a_rows - Rows in the sparse matrix.
a_cols - Colums in the sparse matrix.
a_cols - Columns in the sparse matrix.
a_ell_blocksize - Size of the ELL-Blocks.
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns)

View File

@ -38,10 +38,8 @@
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/cutlass.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"

View File

@ -153,7 +153,7 @@ class gen_device:
warp_M_tile = 32
# Determine maxmimum N_tile
# Determine maximum N_tile
Max_Ntile = 0
for layer in self.fuse_gemm_info:
n_tile = layer['mnk'][1]

View File

@ -76,9 +76,9 @@ class gen_verify:
)
def get_params(self, declartion = True):
def get_params(self, declaration = True):
code = ""
if declartion:
if declaration:
for param in self.params:
code += param[0] + " " + param[1] + ";\n"

View File

@ -64,8 +64,8 @@ def write_2_headfile(filename, file_dir, string):
with open(file_dir + filename, 'w') as f:
f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string)
def var_idx(varaiable, index):
return varaiable + str(index)
def var_idx(variable, index):
return variable + str(index)
def list_2_string(input_list, ):

View File

@ -37,12 +37,10 @@
*/
#pragma once
#include <cuda/std/cassert>
#include "cutlass/array.h"
#include CUDA_STD_HEADER(cassert)
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"

View File

@ -78,7 +78,7 @@
a single default value.
CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of
the collective is specified via the schedule tags that corresond to the underlying collective's
the collective is specified via the schedule tags that correspond to the underlying collective's
dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto`
are special cases of these schedules that allow the builder to also decide the dispatch policy for you,
therefore letting the builder pick the collective specialization.

View File

@ -425,7 +425,7 @@ int main(int argc, char const **args) {
// Pipeline Depth to be used i.e number of A, B buffers in shared memory
constexpr int PipelineStages = 8;
// Let's choose a Warp-Specialized Mainloop implemention which uses TMA
// Let's choose a Warp-Specialized Mainloop implementation which uses TMA
// Note : This requires / assumes the tensors to be 16B aligned
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape,
cutlass::gemm::KernelTmaWarpSpecialized>;

View File

@ -32,7 +32,7 @@
\brief Example of a Hopper gather+GEMM+scatter kernel fusion.
This example fuses gather before GEMM and scatter after GEMM into the same
GEMM kernel. Gather and scatter operation is controled by an index vector
GEMM kernel. Gather and scatter operation is controlled by an index vector
to select rows or columns from A, B, C or D matrices.
Gather/scatter operations are always performed along a strided dimension

View File

@ -65,7 +65,7 @@
The approach relies on two things:
- The ability of CUTLASS 3 to naturally perform general tensor contractions (GETT) owing to the
flexibility of CuTe's hierarchical layouts (see example 51_hopper_gett for more details).
- The harware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
- The hardware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
(almost) arbitrary strides, which can be used to represent a permuted view of the data.
In this example we reuse the permutation classes of examples 39_gemm_permute as operation tags.

View File

@ -188,7 +188,7 @@ Running this example on an RTX 3080Ti prints the following performance numbers (
```
$> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check
Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.
Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.
Allocating tensors ... done.
Initializing data ... done.

View File

@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*! \file
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propogation kernel
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel
capable of operating on both affine and gather/scatter tensors.
This example demonstartes a few super cool features of CUTLASS and CuTe. It shows off
@ -284,7 +284,7 @@ int ampere_gather_scatter_conv_fprop(
int
main(int argc, char const** argv) {
cutlass::CommandLine cmd(argc, argv);
std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n";
std::cout << "Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.\n\n";
if (cmd.check_cmd_line_flag("help")) {
std::cout
<< "Options:\n"

View File

@ -291,7 +291,7 @@ struct Options {
// Post-process the problem sizes
bin_problems();
// Initalize alpha array
// Initialize alpha array
randomize_alpha_ptr_array(cmd);
}

View File

@ -26,7 +26,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 90a)
cutlass_example_add_executable(
65_distributed_gemm
65_distributed_gemm.cu
)
endif()

View File

@ -129,7 +129,7 @@ using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_confi
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
@ -358,7 +358,7 @@ void initialize(const Options<RasterOrderOptions> &options) {
// Layout SFA and SFB represent logically broadcasting data in CuTe.
// E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGraunularityK, K / ScaleGranularityK))
// and strides ((0, 1), (0, M / ScaleGraunuarlityM)), then each collection of ScaleGranularityM x ScaleGranularityK
// indecies in the tensor map to the same offset.
// indices in the tensor map to the same offset.
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));

View File

@ -132,12 +132,12 @@ constexpr int ScaleGranularityK = 128;
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<

View File

@ -145,7 +145,7 @@ using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularity
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
@ -402,12 +402,37 @@ void initialize(const OptionType &options) {
beta_host.clear();
for (int i = 0; i < options.groups; i++) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
// If the current group's matrix has size 0, set the pointer to nullptr
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
ptr_A_host.at(i) = nullptr;
} else {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
}
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
ptr_B_host.at(i) = nullptr;
} else {
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
}
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
ptr_C_host.at(i) = nullptr;
} else {
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
}
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
ptr_D_host.at(i) = nullptr;
} else {
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
ptr_blockscale_A_host.at(i) = nullptr;
} else {
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
}
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
ptr_blockscale_B_host.at(i) = nullptr;
} else {
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
}
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
@ -546,10 +571,10 @@ bool verify(const OptionType &options) {
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
bool passed = true;
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
auto gemm_problem_shape = cute::make_shape(m, n, k);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
@ -598,11 +623,7 @@ bool verify(const OptionType &options) {
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
unused_t, // Aux
unused_t, // valpha
unused_t // vbeta
decltype(D)
> epilogue_params;
epilogue_params.C = C;
@ -639,6 +660,24 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
allocate(options);
initialize(options);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
@ -671,8 +710,7 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
}
// Run profiling loop
if (options.iterations > 0)
{
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
@ -686,25 +724,6 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
fflush(stdout);

View File

@ -132,8 +132,7 @@ using ElementCompute = float; // E
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
static constexpr int ScaleGranularityM = 1;
@ -142,13 +141,13 @@ static constexpr int ScaleGranularityK = 128;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
@ -407,12 +406,37 @@ void initialize(const OptionType &options) {
beta_host.clear();
for (int i = 0; i < options.groups; i++) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
// If the current group's matrix has size 0, set the pointer to nullptr
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
ptr_A_host.at(i) = nullptr;
} else {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
}
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
ptr_B_host.at(i) = nullptr;
} else {
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
}
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
ptr_C_host.at(i) = nullptr;
} else {
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
}
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
ptr_D_host.at(i) = nullptr;
} else {
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
ptr_blockscale_A_host.at(i) = nullptr;
} else {
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
}
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
ptr_blockscale_B_host.at(i) = nullptr;
} else {
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
}
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
@ -551,10 +575,10 @@ bool verify(const OptionType &options) {
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
bool passed = true;
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx);
auto gemm_problem_shape = cute::make_shape(m, n, k);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
@ -637,10 +661,27 @@ bool verify(const OptionType &options) {
template <typename OptionType>
int run(OptionType &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
@ -695,27 +736,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
ScaleMsPerTile,
ScaleNsPerTile>(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
std::cout << " GBPS: " << result.gbps << std::endl;
fflush(stdout);
}
return 0;
@ -766,8 +790,8 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
std::cout << "Running tests with host problem shapes:" << std::endl;
run(options, true);
std::cout << "Running tests without host problem shapes:" << std::endl;
run(options, false);

View File

@ -44,6 +44,9 @@ set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0)
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
set(TEST_K_16B_ALIGNED --m=256 --n=512 --k=960 --groups=10 --iterations=0)
set(TEST_K_16B_ALIGNED_LARGE_GROUP --m=256 --n=512 --k=960 --groups=512 --iterations=0)
cutlass_example_add_executable(
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
@ -58,6 +61,8 @@ cutlass_example_add_executable(
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_K_16B_ALIGNED
TEST_K_16B_ALIGNED_LARGE_GROUP
)
# MSVC will fail to compile this example with the following error:

View File

@ -111,14 +111,14 @@ struct Options {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
if (m < 0) {
m = m_alignment * (rand() % (64 * alignment / m_alignment));
}
if (n < 1) {
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
if (n < 0) {
n = n_alignment * (rand() % (64 * alignment / n_alignment));
}
if (k < 1) {
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
if (k < 0) {
k = k_alignment * (rand() % (32 * alignment / k_alignment));
}
problem_sizes_after_alignment_host.push_back({m, n, k});
problem_sizes_host.push_back({m, n, k});

View File

@ -454,11 +454,12 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -640,11 +640,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}
}
//
// Parse options

View File

@ -33,7 +33,7 @@ set(TEST_SWIZZLE_2 --swizzle=2)
set(TEST_SWIZZLE_5 --swizzle=5)
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
cutlass_example_add_executable(
70_blackwell_fp16_gemm
70_blackwell_fp16_gemm.cu

View File

@ -449,9 +449,9 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}

View File

@ -27,7 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
cutlass_example_add_executable(
71_blackwell_gemm_with_collective_builder
71_blackwell_gemm_with_collective_builder.cu

View File

@ -40,10 +40,10 @@
Similar to 70_blackwell_gemm, this kernel leverages:
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Usage:
@ -116,10 +116,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
// Kernel Perf config
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ArchTag, OperatorClass,
MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
@ -190,13 +190,7 @@ cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_referen
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
return cute::recast_ptr<T>(ptr);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -329,7 +323,7 @@ bool initialize_block(
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
@ -413,7 +407,7 @@ bool verify(const Options &options) {
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
@ -514,13 +508,13 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}
}
//
// Parse options

View File

@ -39,10 +39,10 @@
1. Blockscaled tcgen05.mma instructions.
2. Per-SM memory called Tensor Memory (TMEM)
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Usage:
@ -129,13 +129,13 @@ constexpr int OutputSFVectorSize = InputSFVectorSize;
// With BlockScaleFactor generation.
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
OutputSFVectorSize,
ElementD,
ElementCompute,
ElementD,
ElementCompute,
ElementSFD, LayoutSFDTag,
ElementC>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ArchTag, OperatorClass,
MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
@ -219,13 +219,7 @@ cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_N
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
return cute::recast_ptr<T>(ptr);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -358,7 +352,7 @@ bool initialize_block(
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
@ -456,7 +450,7 @@ bool verify(const Options &options) {
decltype(tensor_B), // TensorB
decltype(tensor_SFB) // TensorSfB
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
@ -569,11 +563,11 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}

View File

@ -41,10 +41,10 @@
1. Blockscaled tcgen05.mma instructions.
2. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
which allows us to decouple the execution of MMA and epilogue into separate warps.
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
Usage:
@ -117,10 +117,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
// Kernel Perf config
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ArchTag, OperatorClass,
MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
@ -191,13 +191,7 @@ cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_referen
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
return cute::recast_ptr<T>(ptr);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -330,7 +324,7 @@ bool initialize_block(
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
@ -414,7 +408,7 @@ bool verify(const Options &options) {
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
cutlass::reference::host::GettBlockScalingEpilogueParams<
ElementAccumulator, // ElementScalar
ElementAccumulator, // ElementAccumulator
@ -515,14 +509,14 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!(props.major == 10 && props.minor == 0)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}
//
// Parse options
//

View File

@ -28,7 +28,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
cutlass_example_add_executable(
72a_blackwell_nvfp4_bf16_gemm
72a_blackwell_nvfp4_bf16_gemm.cu

View File

@ -28,7 +28,7 @@
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
cutlass_example_add_executable(
73_blackwell_gemm_preferred_cluster
blackwell_gemm_preferred_cluster.cu

View File

@ -513,7 +513,7 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
return 0;
}

View File

@ -29,9 +29,9 @@
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
74_blackwell_gemm_streamk
blackwell_gemm_streamk.cu
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
cutlass_example_add_executable(
74_blackwell_gemm_streamk
blackwell_gemm_streamk.cu
)
endif()

View File

@ -61,7 +61,7 @@
# Heuristic mode with deterministic reduction
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic
# Stream-K mode with determinsitic reduction
# Stream-K mode with deterministic reduction
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic
# Split-K mode with a splitting factor of 2 and deterministic reduction
@ -556,10 +556,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -762,9 +762,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (!(props.major == 10 && props.minor == 0)) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}

View File

@ -130,16 +130,15 @@ constexpr int OutputSFVectorSize = 16;
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
cutlass::epilogue::thread::SiLu,
OutputSFVectorSize,
ElementD,
ElementAccumulator,
ElementD,
ElementAccumulator,
ElementSFD,
LayoutC,
ElementC>;
// Core kernel configurations
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
// Runtime Cluster Shape
@ -159,7 +158,7 @@ struct MMA2SMConfig {
};
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass,
ArchTag, OperatorClass,
typename MMA1SMConfig::MmaTileShape, ClusterShape,
Shape<_128,_64>,
ElementAccumulator, ElementAccumulator,
@ -169,7 +168,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
// , FusionOperation // Enable for SF Output
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass,
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
@ -187,7 +186,7 @@ using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass,
ArchTag, OperatorClass,
typename MMA2SMConfig::MmaTileShape, ClusterShape,
Shape<_128,_64>,
ElementAccumulator, ElementAccumulator,
@ -197,13 +196,13 @@ using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::Collective
// , FusionOperation // Enable for SF Output
>::CollectiveOp;
using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass,
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
typename MMA2SMConfig::MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
static_cast<int>(sizeof(typename CollectiveEpilogue2SM::SharedStorage))>,
typename MMA2SMConfig::KernelSchedule
>::CollectiveOp;
using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal<
@ -222,7 +221,7 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutS
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig<
OutputSFVectorSize,
OutputSFVectorSize,
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
>;
@ -233,7 +232,7 @@ using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<LayoutSFA> layout_SFA_host;
std::vector<LayoutSFA> layout_SFB_host;
std::vector<LayoutSFB> layout_SFB_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
@ -287,13 +286,7 @@ cutlass::DeviceAllocation<ElementAccumulator> norm_constant_device;
template <typename T>
auto make_iterator(T* ptr) {
using namespace cute;
if constexpr (cute::is_subbyte_v<T>) {
return subbyte_iterator<T>(ptr);
}
else {
return ptr;
}
return cute::recast_ptr<T>(ptr);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -529,7 +522,7 @@ bool initialize_block(
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
@ -785,9 +778,9 @@ bool verify(const Options &options) {
decltype(tensor_SFA),
decltype(tensor_B),
decltype(tensor_SFB)
>
>
mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C);
auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D);
@ -856,8 +849,8 @@ int run(Options &options, bool host_problem_shapes_available = true)
}
}
else {
std::cout << " Verfication is turned off for this run." << std::endl;
}
std::cout << " Verification is turned off for this run." << std::endl;
}
// Run profiling loop
if (options.iterations > 0)
@ -903,9 +896,8 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (!(props.major == 10 && props.minor == 0)) {
std::cerr
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}
@ -933,7 +925,7 @@ int main(int argc, char const **args) {
std::cout << "Running kernel with 1SM MMA config:" << std::endl;
run<Gemm1SM>(options, false /*host_problem_shapes_available*/);
std::cout << "Running kernel with 2SM MMA config:" << std::endl;
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
#endif
return 0;

View File

@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
if("100a" IN_LIST CUTLASS_NVCC_ARCHS)
cutlass_example_add_executable(
75_blackwell_grouped_gemm
75_blackwell_grouped_gemm.cu

View File

@ -36,7 +36,7 @@
APIs on NVIDIA Blackwell SM100 architecture.
The basic computation logic of dgrad convolution kernel is, take 3D convolution as an example:
Xformed Actication (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
Xformed Activation (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
where in terms of GEMM perspective,
Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation
@ -504,10 +504,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -36,7 +36,7 @@
APIs on NVIDIA Blackwell SM100 architecture.
The basic computation logic of fprop convolution kernel is, take 3D convolution as an example:
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK)
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Activation (NZPQK)
where in terms of GEMM perspective,
Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation
@ -504,10 +504,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -36,7 +36,7 @@
APIs on NVIDIA Blackwell SM100 architecture.
The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example:
Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
Xformed Activation (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
where in terms of GEMM perspective,
Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter
@ -500,10 +500,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}
//
// Parse options
//

View File

@ -126,6 +126,7 @@ struct Options {
bool verbose = false;
bool causal = false;
bool causal_q_begin = true;
bool residual = false;
bool varlen = false;
bool persistent = false;
@ -266,6 +267,8 @@ struct Options {
std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
std::string causal_type;
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
if (mask == "no" || mask == "") {
causal = residual = false;
if (varlen) {
@ -275,6 +278,11 @@ struct Options {
else if (mask == "causal") {
residual = false;
causal = true;
if(causal_type == "qend") {
causal_q_begin = false;
} else {
causal_q_begin = true;
}
}
else if (mask == "residual") {
residual = true;
@ -313,6 +321,7 @@ struct Options {
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --causal-type=<qbegin|qend> Causal mask type\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
@ -410,16 +419,16 @@ struct FwdRunner {
using ElementAccumulatorPV = float;
using ElementOut = cutlass::half_t;
// Q K D (B H)
// Q K D ((H_R, H_K) B)
using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D (H_G H_R B)
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D (H_G H_R B)
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D ((H_R, H_K), B)
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D ((H_R, H_K), B)
using StrideV = StrideK;
using StrideO = StrideQ;
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q (H_G H_R B)
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q ((H_R, H_K), B)
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
@ -505,8 +514,12 @@ struct FwdRunner {
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
auto [Q, K, D, HB] = problem_shape;
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB);
fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{});
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
@ -598,8 +611,8 @@ struct FwdRunner {
ProblemShapeType problem_size_for_launch;
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
get<2>(problem_size_for_launch) = get<2>(problem_size);
get<3>(problem_size_for_launch) = get<3>(problem_size);
@ -656,9 +669,9 @@ struct FwdRunner {
}
auto buffer_init_fn = [&](auto& buffer) {
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
buffer.block_Q.reset(size(shape_QO));
buffer.block_K.reset(size(shape_KV));
buffer.block_V.reset(size(shape_KV));
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
buffer.block_LSE.reset(size(shape_LSE));
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
@ -853,7 +866,7 @@ struct FwdRunner {
flops *= static_cast<double>(size<1>(problem_shape));
flops *= static_cast<double>(size<3,1>(problem_shape));
}
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask<true>> || std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
flops *= static_cast<double>(size<2>(problem_shape));
flops *= static_cast<double>(size<3,0>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
@ -888,11 +901,18 @@ struct FwdRunner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
if (! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
@ -1067,7 +1087,11 @@ int main_single(int argc, char const **args) {
auto with_mask = [&](auto fn) {
if (options.causal) {
fn(CausalMask{});
if(options.causal_q_begin) {
fn(CausalMask{});
} else {
fn(CausalMask<false>{});
}
}
else if (options.residual) {
fn(ResidualMask{});
@ -1093,7 +1117,7 @@ int main_single(int argc, char const **args) {
});
#endif
return 0;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1101,8 +1125,6 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -1129,7 +1151,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return result;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -32,7 +32,7 @@
\brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3.
This example showcases the use of CUTLASS to build backward fused
multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting
multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting
the NVIDIA Blackwell architecture.
Background and motivation
@ -114,12 +114,17 @@ struct Options {
int h_k = 1;
int q = 1024;
int k = 1024;
std::vector<int> varlen_q;
std::vector<int> varlen_k;
int d = 128;
int d_vo = 128;
int iterations = 3;
bool verify = false;
bool verbose = false;
bool causal = false;
bool residual = false;
bool varlen = false;
int sm_count = 0;
std::string kernel_filter;
@ -174,16 +179,82 @@ struct Options {
}
cmd.get_cmd_line_argument("d", d, defaults.d);
cmd.get_cmd_line_argument("d_vo", d_vo, d);
cmd.get_cmd_line_argument("h", h, -1);
if (h == -1) h = 2048 / d;
cmd.get_cmd_line_argument("h_k", h_k, -1);
if (h_k == -1) h_k = h;
varlen = cmd.check_cmd_line_flag("varlen");
cmd.get_cmd_line_argument("q", q, -1);
cmd.get_cmd_line_argument("k", k, -1);
cmd.get_cmd_line_argument("b", b, -1);
std::string varlen_q_str;
cmd.get_cmd_line_argument("varlen-q", varlen_q_str);
std::string varlen_k_str;
cmd.get_cmd_line_argument("varlen-k", varlen_k_str);
if (varlen && ! varlen_q_str.empty()) {
varlen_q.clear();
while (! varlen_q_str.empty()) {
size_t pos = varlen_q_str.find(':');
varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_q_str = varlen_q_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_q.size());
}
if (b != static_cast<int>(varlen_q.size())) {
std::cout << "Error: Invalid --varlen-q length\n";
std::exit(-1);
}
int new_q = 0;
for (auto elem : varlen_q) {
new_q += elem;
}
if (q != -1) {
std::cout << "Error: Can't provide --q and --varlen-q\n";
std::exit(-1);
}
q = new_q;
}
if (varlen && ! varlen_k_str.empty()) {
varlen_k.clear();
while (! varlen_k_str.empty()) {
size_t pos = varlen_k_str.find(':');
varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos)));
if (pos == std::string::npos) {
break;
}
varlen_k_str = varlen_k_str.substr(pos + 1);
}
if (b == -1) {
b = static_cast<int>(varlen_k.size());
}
if (b != static_cast<int>(varlen_k.size())) {
std::cout << " Error: Invalid --varlen-k length\n";
std::exit(-1);
}
int new_k = 0;
for (auto elem : varlen_k) {
new_k += elem;
}
if (k != -1) {
std::cout << "Error: Can't provide --k and --varlen-k\n";
std::exit(-1);
}
k = new_k;
}
if (q == -1) q = k;
if (k == -1) k = q;
if (q == -1 && k == -1) q = k = defaults.q;
cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;
@ -195,9 +266,15 @@ struct Options {
if (mask == "causal") {
causal = true;
}
else if (mask == "residual") {
residual = true;
}
else {
causal = defaults.causal;
}
if (varlen) {
residual = true;
}
skip_reference = cmd.check_cmd_line_flag("skip-reference");
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
@ -224,13 +301,22 @@ struct Options {
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the B extent\n"
<< " --h=<int> Sets the H extent\n"
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
<< " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn"
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
<< " --d=<int> Sets the D extent\n"
<< " --d_vo=<int> Sets the D_VO extent\n"
<< " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|causal> Enables masking\n"
<< " --mask=<no|residual|causal> Enables masking\n"
<< " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n"
<< " with the last batch sized to make it fit\n"
<< " implies at least residual masking for correctness\n"
<< " --sm-count Sets SM count rather than querying it\n"
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
<< "\n";
@ -307,6 +393,8 @@ struct ExampleResult {
///////////////////////////////////////////////////////////////////////////////////////////////////
template<
bool kIsVarlen,
bool kIsMla,
class TileShape,
class DispatchPolicy,
class ActiveMask,
@ -321,23 +409,24 @@ struct BwdRunner {
#endif
using ElementAccumulator = float;
// Q K D (H B)
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
// Q K D D_VO ((H_R, H_K) B)
using ProblemShape = std::conditional_t<
kIsVarlen,
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
>;
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
using StrideQ = TensorStride;
using StrideK = TensorStride;
using StrideV = TensorStride;
using StrideO = TensorStride;
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)
// Backwards specific
using StrideDQ = TensorStride;
using StrideDK = TensorStride;
using StrideDV = TensorStride;
using StrideDO = TensorStride;
using StrideDQ = StrideQ;
using StrideDK = StrideK;
using StrideDV = StrideV;
using StrideDO = StrideO;
//
// Data members
@ -363,6 +452,9 @@ struct BwdRunner {
DeviceAllocation<Element> block_O;
DeviceAllocation<ElementAccumulator> block_LSE;
DeviceAllocation<int> block_cumulative_seqlen_q;
DeviceAllocation<int> block_cumulative_seqlen_kv;
DeviceAllocation<Element> block_dQ;
DeviceAllocation<Element> block_dK;
DeviceAllocation<Element> block_dV;
@ -375,47 +467,19 @@ struct BwdRunner {
//
// Methods
//
bool verify(const ProblemShapeType& problem_shape) {
auto [Q, K, D, HB] = problem_shape;
bool verify(const ProblemShape& problem_shape) {
auto [Q, K, D, D_VO, HB] = problem_shape;
auto [H, B] = HB;
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,2,3>(problem_shape),
stride_O);
// keep going here! (this might be better in cursor)
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,3>(problem_shape),
stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
select<0,2,3>(problem_shape),
stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
select<1,2,3>(problem_shape),
stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
select<1,2,3>(problem_shape),
stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
select<0,2,3>(problem_shape),
stride_dO);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), make_shape(Q, D, HB), stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), make_shape(K, D, HB), stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), make_shape(K, D_VO, HB), stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), make_shape(Q, D_VO, HB), stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), make_shape(Q, HB), stride_LSE);
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), make_shape(Q, D, HB), stride_dQ);
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), make_shape(K, D, HB), stride_dK);
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV);
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO);
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
@ -459,22 +523,94 @@ struct BwdRunner {
return passed_dQ && passed_dK && passed_dV;
}
auto initialize_problem_shape(Options const& options) {
int h_r = options.h / options.h_k;
assert(options.h % options.h_k == 0);
if constexpr (kIsVarlen) {
int num_batches = options.b;
// generate Q as --b times
// gaussian (--Q, --Q / 2) sampled positive
// track cumulative
std::mt19937 rng(0x202305151552ull);
std::normal_distribution<double> dist_q(options.q, options.q / 2);
std::normal_distribution<double> dist_kv(options.k, options.k / 2);
auto generate_positive_int = [](auto& dist, auto& gen) {
// "0" is a valid value we test here
return std::max(0, static_cast<int>(dist(gen)));
};
std::vector<int> cumulative_seqlen_q = {0};
std::vector<int> cumulative_seqlen_kv = {0};
int total_seqlen_q = 0;
int total_seqlen_kv = 0;
int max_seqlen_q = 0;
int max_seqlen_kv = 0;
const bool kVarlenSame = false;
for (int i = 0; i < num_batches; i++) {
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
kVarlenSame ? options.q :
generate_positive_int(dist_q, rng);
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
kVarlenSame ? options.k :
generate_positive_int(dist_kv, rng);
total_seqlen_q += seqlen_q;
total_seqlen_kv += seqlen_kv;
max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
}
block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
ProblemShape problem_shape{
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
options.d, options.d_vo, {{h_r, options.h_k}, options.b}
};
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(make_shape(h_r, options.h_k), 1));
return cute::make_tuple(problem_shape, tensor_shape);
}
else {
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {{h_r, options.h_k}, options.b}};
return cute::make_tuple(problem_shape, problem_shape);
}
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
auto [Q, K, D, HB] = problem_shape;
ProblemShape initialize(Options const& options) {
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
auto [Q, K, D, D_VO, HB] = tensor_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto shape_QO = select<0,2,3>(problem_shape);
auto shape_KV = select<1,2,3>(problem_shape);
auto shape_LSE = select<0,3>(problem_shape);
// for varlen, Q == total_Q, K == total_K, B = 1
// but in problem_shape, they've got to be max_Q/max_K, and B = B
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
stride_V = stride_K;
stride_O = stride_Q;
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
auto shape_Q = make_shape(Q, D, HB);
auto shape_K = make_shape(K, D, HB);
auto shape_V = make_shape(K, D_VO, HB);
auto shape_O = make_shape(Q, D_VO, HB);
auto shape_LSE = make_shape(Q, HB);
stride_Q = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), B == 1 ? 0 : D*K*H_K));
stride_V = make_stride(D_VO, _1{}, make_stride(make_stride(_0{},D_VO*K), B == 1 ? 0 : D_VO*K*H_K));
stride_O = make_stride(D_VO, _1{}, make_stride(make_stride(D_VO*Q, D_VO*Q*H_R), B == 1 ? 0 : D_VO*Q*H_R*H_K));
stride_LSE = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
stride_dQ = stride_Q;
stride_dK = stride_K;
@ -485,58 +621,72 @@ struct BwdRunner {
return size(make_shape(1ull, shape));
};
block_Q.reset(lsize(shape_QO));
block_K.reset(lsize(shape_KV));
block_V.reset(lsize(shape_KV));
block_O.reset(lsize(shape_QO));
auto size_K = lsize(K * D * H_K * B);
auto size_V = lsize(K * D_VO * H_K * B);
block_Q.reset(lsize(shape_Q));
block_K.reset(size_K);
block_V.reset(size_V);
block_O.reset(lsize(shape_O));
block_LSE.reset(lsize(shape_LSE));
block_dQ.reset(lsize(shape_QO));
block_dK.reset(lsize(shape_KV));
block_dV.reset(lsize(shape_KV));
block_dO.reset(lsize(shape_QO));
block_dQ.reset(lsize(shape_Q));
block_dK.reset(size_K);
block_dV.reset(size_V);
block_dO.reset(lsize(shape_O));
block_ref_dQ.reset(lsize(shape_QO));
block_ref_dK.reset(lsize(shape_KV));
block_ref_dV.reset(lsize(shape_KV));
block_ref_dQ.reset(lsize(shape_Q));
block_ref_dK.reset(size_K);
block_ref_dV.reset(size_V);
initialize_block(block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v);
initialize_block(block_dO, seed + 2020, options.init_style_do);
initialize_block(block_dQ, seed + 2030, InitStyle::kOne);
initialize_block(block_dK, seed + 2031, InitStyle::kOne);
initialize_block(block_dV, seed + 2032, InitStyle::kOne);
initialize_block(block_ref_dQ, seed + 2033);
initialize_block(block_ref_dK, seed + 2034);
initialize_block(block_ref_dV, seed + 2035);
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
select<0,2,3>(problem_shape),
select<0,2,4>(problem_shape),
stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
select<1,2,3>(problem_shape),
select<1,2,4>(problem_shape),
stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
select<1,2,3>(problem_shape),
select<1,3,4>(problem_shape),
stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
select<0,2,3>(problem_shape),
select<0,3,4>(problem_shape),
stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
select<0,3>(problem_shape),
select<0,4>(problem_shape),
stride_LSE);
if (! options.skip_reference) {
if (not options.skip_reference) {
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
}
return problem_shape;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
initialize(problem_shape, options);
auto problem_shape = initialize(options);
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
ExampleResult example_result;
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, kIsMla, ActiveMask>;
typename Operation::Arguments arguments{
problem_shape,
block_Q.get(), stride_Q,
@ -554,8 +704,6 @@ struct BwdRunner {
Operation op;
ExampleResult example_result;
example_result.smem_size = Operation::Kernel::SharedStorageSize;
size_t workspace_size = 0;
@ -650,12 +798,13 @@ struct BwdRunner {
runtime_ms /= static_cast<float>(options.iterations);
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask<false>> || std::is_same_v<ActiveMask, CausalForBackwardMask<true>> ? 0.5 : 1.0);
flops *= static_cast<double>(get<0>(problem_shape));
flops *= static_cast<double>(get<1>(problem_shape));
flops *= static_cast<double>(get<2>(problem_shape));
flops *= static_cast<double>(get<3,0>(problem_shape));
flops *= static_cast<double>(get<3,1>(problem_shape));
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
flops *= static_cast<double>(get<4,0,0>(problem_shape));
flops *= static_cast<double>(get<4,0,1>(problem_shape));
flops *= static_cast<double>(get<4,1>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
example_result.tflops_tc_s = tflops_s;
example_result.runtime_ms = runtime_ms;
@ -688,11 +837,18 @@ struct BwdRunner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
if (! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
@ -706,19 +862,33 @@ void print_result(const std::string& description, ExampleResult result, bool ver
struct KernelCoop {};
///////////////////////////////////////////////////////////////////////////////////////////////////
template<class Fn>
auto dispatch_bool(bool value, Fn fn) {
if (value) {
return fn(std::true_type{});
}
else {
return fn(std::false_type{});
}
}
//////////////////////////////////////////////////////////////////////////////////////////////////
template<class Mask>
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, false,decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
};
using HeadDim = _64;
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -726,14 +896,31 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
template<class Mask>
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, false, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
};
using HeadDim = _128;
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
}
template<class Mask>
void run_bwd_mla_192(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
dispatch_bool(options.varlen, [&](auto is_varlen) {
BwdRunner<decltype(is_varlen)::value, true, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
auto result = runner.run(options, hw_info);
print_result(name, result, options.verbose);
});
};
using HeadDim = _192;
run(Shape<_64, _128, HeadDim, _128>{}, KernelCoop{}, "tma");
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -797,13 +984,16 @@ int main_single(int argc, char const **args) {
hw_info.sm_count = options.sm_count;
}
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
std::cout << "#SM " << hw_info.sm_count << std::endl;
auto with_causal = [&](auto fn) {
if (options.causal) {
fn(CausalMask{});
fn(CausalForBackwardMask{});
}
else if (options.residual) {
fn(ResidualMaskForBackward{});
}
else {
fn(NoMask{});
@ -811,19 +1001,22 @@ int main_single(int argc, char const **args) {
};
with_causal([&](auto fusion) {
if (options.d <= 64) {
if (options.d <= 64 && options.d_vo == options.d) {
run_bwd_64(fusion, options, hw_info);
}
else if (options.d <= 128) {
else if (options.d <= 128 && options.d_vo == options.d) {
run_bwd_128(fusion, options, hw_info);
}
else if (options.d == 192 && options.d_vo == 128) {
run_bwd_mla_192(fusion, options, hw_info);
}
else {
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
}
});
#endif
return 0;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -831,8 +1024,6 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -859,7 +1050,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return result;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -689,11 +689,18 @@ struct ExampleRunner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
if (result.supported && ! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
@ -781,12 +788,17 @@ int main_single(int argc, char const **args) {
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
)
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
if (options.d == 128) {
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
}
else {
std::cout << "Head Dimension != 128 is not supported for the fmha_gen example\n";
}
#endif
return 0;
@ -797,8 +809,6 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -825,7 +835,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return result;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -80,6 +80,7 @@ struct Options {
int iterations = 3;
bool verify = false;
bool verbose = false;
bool is_fused_reduction = false;
int sm_count = 0;
@ -139,9 +140,12 @@ struct Options {
if (b == 0) b = 1;
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
if (split_kv == 0) {
split_kv = 1;
}
cmd.get_cmd_line_argument("page", page, defaults.page);
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
is_var_split_kv = cmd.check_cmd_line_flag("var_split_kv");
if (page == -1) {
is_var_split_kv = false;
}
@ -149,6 +153,10 @@ struct Options {
if (is_var_split_kv == true) {
split_kv = max_split_kv;
}
is_fused_reduction = cmd.check_cmd_line_flag("fuse_reduction");
if (split_kv == 1) {
is_fused_reduction = false;
}
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose");
@ -176,6 +184,8 @@ struct Options {
<< " --iterations=<int> Benchmarking iterations\n"
<< " --spread=<float> Relative spread away from K for paging\n"
<< " --split_kv=<int> Split KV factor\n"
<< " --fused_reduction Fuse the reduction operation\n"
<< " --var_split_kv Use varying split KV factor\n"
<< " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n"
<< " --sm-count Sets SM count rather than querying it\n"
@ -391,11 +401,7 @@ struct Runner {
// Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0;
double mean_diff = 0;
#ifdef B2B
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
#else
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
#endif
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) {
@ -404,7 +410,6 @@ struct Runner {
}
bool passed_LSE = true;
#ifndef B2B
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
@ -412,7 +417,6 @@ struct Runner {
std::cerr << "failed LSE: max diff " << max_diff
<< " mean " << mean_diff << std::endl;
}
#endif
return passed_O && passed_LSE;
}
@ -520,7 +524,8 @@ struct Runner {
stride_LSE},
hw_info,
options.split_kv,
options.is_var_split_kv ? block_split_kv.get() : nullptr
options.is_var_split_kv ? block_split_kv.get() : nullptr,
options.is_fused_reduction
};
if (options.split_kv < 0 && !options.is_var_split_kv) {
Operation::set_split_kv(arguments);
@ -678,11 +683,18 @@ struct Runner {
///////////////////////////////////////////////////////////////////////////////////////////////////
int main_result = 0;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to print a description of the example run and its result
void print_result(const std::string& description, ExampleResult result, bool verbose) {
std::ios fmt(nullptr);
fmt.copyfmt(std::cout);
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
if (! result.passed) {
main_result = -1;
}
std::cout << std::setw(32) << std::left << description;
std::cout.copyfmt(fmt);
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
@ -723,13 +735,17 @@ void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
if (!options.is_fused_reduction || options.split_kv == 1) {
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
}
#elif FP16
name += " fp16";
// Persistent Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
// Individual Tile Scheduler
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
if (!options.is_fused_reduction || options.split_kv == 1) {
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
}
#endif
}
@ -806,8 +822,6 @@ int main_single(int argc, char const **args) {
int main(int argc, char const **args) {
std::vector<std::string> full_arguments(args, args + argc);
int result = 0;
bool recursed = false;
for (size_t i = 1; i < full_arguments.size(); i++) {
if (full_arguments[i].find(',') != std::string::npos) {
@ -834,7 +848,7 @@ int main(int argc, char const **args) {
main_single(argc, args);
}
return result;
return main_result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large Load Diff

View File

@ -33,40 +33,84 @@ set_property(
77_blackwell_fmha_gen.cu
77_blackwell_mla.cu
77_blackwell_fmha_bwd.cu
77_blackwell_mla_fwd.cu
PROPERTY
COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
)
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
set(TEST_VARLEN_00 --verify --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_01 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_02 --verify --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_03 --verify --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_04 --verify --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_05 --verify --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_06 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
set(TEST_VARLEN_07 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
set(TEST_VARLEN_08 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
set(TEST_VARLEN_09 --verify --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
set(TEST_VARLEN_10 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
set(TEST_VARLEN_11 --verify --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
set(TEST_VARLEN_12 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
set(TEST_VARLEN_13 --verify --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_VARLEN_00 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_01 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_02 --verify --varlen --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_03 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_04 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_05 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_VARLEN_06 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
set(TEST_VARLEN_07 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
set(TEST_VARLEN_08 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
set(TEST_VARLEN_09 --verify --varlen --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
set(TEST_VARLEN_10 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
set(TEST_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_02 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_03 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_MLA_FWD_VARLEN_04 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_MLA_FWD_VARLEN_05 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
set(TEST_MLA_FWD_VARLEN_06 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
set(TEST_MLA_FWD_VARLEN_07 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
set(TEST_MLA_FWD_VARLEN_08 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
set(TEST_MLA_FWD_VARLEN_09 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
set(TEST_MLA_FWD_VARLEN_10 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=2:3 --varlen-k=2:5)
set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=11:10 --varlen-k=13:10)
set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
set(TEST_MLA_FWD_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
set(TEST_MLA_FWD_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
set(TEST_MLA_FWD_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify)
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify)
set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify)
set(TEST_MLA_LARGE_SPLIT_KV --verify --split_kv=20 --page=128)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
@ -78,7 +122,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_fmha.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_CAUSAL
TEST_CAUSAL_00
TEST_CAUSAL_01
TEST_VARLEN
TEST_HDIM64
TEST_GQA
@ -97,6 +142,14 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
TEST_VARLEN_15
TEST_VARLEN_16
TEST_VARLEN_17
TEST_VARLEN_18
TEST_VARLEN_19
TEST_VARLEN_20
TEST_VARLEN_21
TEST_VARLEN_22
)
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
@ -107,7 +160,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
TEST_COMMAND_OPTIONS
TEST_GEN_BASIC
TEST_GEN_VARLEN
TEST_GEN_HDIM64
# TEST_GEN_HDIM64
TEST_GEN_GQA
TEST_GEN_REMAP
TEST_GEN_CACHEONLY
@ -120,6 +173,9 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
TEST_MLA_LARGE_SPLIT_KV
TEST_MLA_SEP_REDUCTION
TEST_MLA_FUSE_REDUCTION
)
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
@ -130,46 +186,79 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
TEST_MLA_LARGE_SPLIT_KV
TEST_MLA_SEP_REDUCTION
TEST_MLA_FUSE_REDUCTION
)
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_mla_b2b_2sm_${PREC}
77_blackwell_mla.cu
TEST_COMMAND_OPTIONS
TEST_MLA_BASIC
)
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_${PREC}
77_blackwell_fmha_bwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
TEST_VARLEN
# NOTE: bwd doesn't support GQA yet, --h_k will just get ignored in these tests
TEST_VARLEN_00
TEST_VARLEN_01
TEST_VARLEN_02
TEST_VARLEN_03
TEST_VARLEN_04
TEST_VARLEN_05
TEST_VARLEN_06
TEST_VARLEN_07
TEST_VARLEN_08
TEST_VARLEN_09
TEST_VARLEN_10
TEST_VARLEN_11
TEST_VARLEN_12
TEST_VARLEN_13
TEST_VARLEN_14
TEST_BWD_MLA_BASIC
TEST_BWD_MLA_VARLEN
)
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
cutlass_example_add_executable(
77_blackwell_fmha_bwd_sat_${PREC}
77_blackwell_fmha_bwd.cu
77_blackwell_mla_fwd_${PREC}
77_blackwell_mla_fwd.cu
TEST_COMMAND_OPTIONS
TEST_BASIC
# TEST_GEN_VARLEN
TEST_GEN_HDIM64
# TEST_GEN_GQA
# TEST_GEN_REMAP
# TEST_GEN_CACHEONLY)
TEST_CAUSAL_00
TEST_VARLEN
TEST_HDIM64
TEST_GQA
TEST_MLA_FWD_VARLEN_00
TEST_MLA_FWD_VARLEN_01
TEST_MLA_FWD_VARLEN_02
TEST_MLA_FWD_VARLEN_03
TEST_MLA_FWD_VARLEN_04
TEST_MLA_FWD_VARLEN_05
TEST_MLA_FWD_VARLEN_06
TEST_MLA_FWD_VARLEN_07
TEST_MLA_FWD_VARLEN_08
TEST_MLA_FWD_VARLEN_09
TEST_MLA_FWD_VARLEN_10
TEST_MLA_FWD_VARLEN_11
TEST_MLA_FWD_VARLEN_12
TEST_MLA_FWD_VARLEN_13
TEST_MLA_FWD_VARLEN_14
TEST_MLA_FWD_VARLEN_15
TEST_MLA_FWD_VARLEN_16
TEST_MLA_FWD_VARLEN_17
TEST_MLA_FWD_VARLEN_18
TEST_MLA_FWD_VARLEN_19
TEST_MLA_FWD_VARLEN_20
TEST_MLA_FWD_VARLEN_21
TEST_MLA_FWD_VARLEN_22
)
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})
target_compile_options(77_blackwell_mla_fwd_${PREC} PRIVATE -Xptxas -v)
endforeach()
# Add a target that builds all examples
@ -183,9 +272,9 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_mla_b2b_2sm_fp8
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_mla_fwd_fp8
77_blackwell_mla_fwd_fp16
)
endif()

View File

@ -8,7 +8,7 @@ For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit
Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.
For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
The kernel and collective layer are then formulated to be fmha-specific.
@ -37,13 +37,19 @@ There are three kernels to compute backwards:
`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.
## MLA Blackwell Backward
The sample also provides the feature of MLA backward(d=192, d_vo=128). To enable MLA backward, please specify `--d=192 --d_vo=128` when running the bwd sample.
`Sm100FmhaBwdMlaKernelTmaWarpSpecialized`is the main point for MLA backward. The MLA approach is slightly different from the original one to enable high performance with the MLA shape.
# MLA Inference for Blackwell
This sample provides code for fused multi-head latent attention inference in
the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
It supports fp16, bf16, and fp8 input and output types.
To accomodate the large output accumulator due to the large latent head dimension,
To accommodate the large output accumulator due to the large latent head dimension,
the sample demonstrates how to leverage 2Sm Blackwell tensor cores.
Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
@ -55,6 +61,14 @@ The approach of this implementation is to reuse the selection logic of the colle
The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back gemm (essentially turning the softmax into a no-op) for fp8 and fp16.
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them.
# Changes
* 4.1.0: Enhanced testing of variable sequence length; disabled B2B mode in MLA
to simplify the sample, clarified that `fmha_gen` sample only supports head
dim 128.
* 4.3.0: For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

View File

@ -132,10 +132,68 @@ struct ResidualMask : NoMask {
}
};
struct ResidualMaskForBackward : NoMask {
using Base = NoMask;
template <class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return 1;
}
return 0;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
// if the sequence length does not divide the tile size evenly
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// This is useful is seqlen_k % kBlockN != 0 since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (! elem_less(pos, select<0,1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
};
// There are two ways to do causal if N_Q != N_K
// (1) The Q is at the beginning of the matrix
// (2) The Q is at the end of the matrix
template<bool kIsQBegin = true>
struct CausalMask : NoMask {
using Base = NoMask;
static constexpr bool IsQBegin = kIsQBegin;
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
@ -146,8 +204,14 @@ struct CausalMask : NoMask {
// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
if constexpr (IsQBegin) {
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
} else {
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
@ -156,9 +220,14 @@ struct CausalMask : NoMask {
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
@ -171,6 +240,47 @@ struct CausalMask : NoMask {
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is the default setting.
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to set kIsQBegin=false
if constexpr (IsQBegin) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
} else {
const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
acc_qk(i) = -INFINITY;
}
}
}
}
};
template<bool kIsQBegin = true>
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
using Base = CausalMask<kIsQBegin>;
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void apply_mask(
@ -186,10 +296,16 @@ struct CausalMask : NoMask {
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
int offset_q = 0;
if constexpr (!kIsQBegin) {
offset_q = get<1>(problem_size) - get<0>(problem_size);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
if (masked) {
acc_qk(i) = -INFINITY;
}
}
@ -200,22 +316,23 @@ struct CausalMask : NoMask {
struct VariableLength {
int max_length;
int* cumulative_length = nullptr;
int total_length = -1;
CUTE_HOST_DEVICE operator int() const {
return max_length;
}
};
template<class T> struct is_variable_length : std::false_type {};
template<> struct is_variable_length<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
return transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
@ -230,7 +347,7 @@ constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
auto new_shape = apply_variable_length(shape, idx);
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
if constexpr (is_variable_length_v<decltype(s)>) {
return cute::make_tuple(c, s.cumulative_length[idx]);
}
else {
@ -240,6 +357,30 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
return cute::make_tuple(new_shape, new_coord);
}
template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
auto idx = back(back(coord));
auto result_shape = transform_leaf(shape, [&](auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
}
else {
return s;
}
});
auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
if constexpr (is_variable_length_v<decltype(s)>) {
return s.cumulative_length[idx];
}
else {
return _0{};
}
});
return cute::make_tuple(result_shape, result_offset);
}
} // namespace cutlass::fmha::collective
namespace cute {

View File

@ -42,7 +42,8 @@ template<
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_ // Q, B
class StrideLSE_, // Q, B
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
@ -55,7 +56,11 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
using ElementOut = Element;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
struct TensorStorage {
using SmemLayoutO = SmemLayoutO_;
@ -85,6 +90,19 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
StrideLSE dLSE;
};
// FMHA and MLA have different input ProblemShapes;
// get problem_shape_O according to the input ProblemShape.
template<class ProblemShape>
CUTLASS_DEVICE static constexpr
auto get_problem_shape_O (
ProblemShape const& problem_shape) {
if constexpr (rank_v<decltype(get<2>(ProblemShape{}))> == 2) {
return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));
} else {
return select<0,2,3>(problem_shape);
}
}
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
@ -93,7 +111,8 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
auto ptr_O = args.ptr_O;
StrideO dO = args.dO;
auto problem_shape_O = select<0,2,3>(problem_shape);
auto problem_shape_O = get_problem_shape_O(problem_shape);
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
@ -145,7 +164,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
int o0_index = 2 * get<0>(blk_coord);
int o1_index = 2 * get<0>(blk_coord) + 1;
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(select<0,2,3>(problem_shape));
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
// offset mode 0 by (max_length - real_length)
// offset mode 3,1 by cumulative_length + real_length
// the ptr is already offset by - max_length
@ -200,6 +219,11 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
tma_store_wait<0>();
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;

View File

@ -58,7 +58,9 @@ template<
// and referes to the two softmax warps
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
// (1, 2, 1) means they sit side by side (best for small Q / large K)
class ThreadShape = Shape<_2, _1, _1>
class ThreadShape = Shape<_2, _1, _1>,
// Since shared memory is sufficient for FMHA, there is no need to reuse shared memory.
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdMainloopTmaWarpspecialized {
@ -76,7 +78,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
using ClusterShape = Shape<_1, _1, _1>;
static const int Alignment = 128 / sizeof_bits_v<Element>;
@ -106,6 +108,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountKV>{}));
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountKV>{}));
// Reuse shared memory for V and O.
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
struct TensorStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
@ -168,9 +172,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadKV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>), "K and V smem layouts must be of equal size");
static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size");
using Load = Sm100FmhaLoadTmaWarpspecialized<
Element, StrideQ, StrideK, StrideV,
@ -525,7 +530,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
@ -663,7 +668,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
pipeline_c.producer_acquire(pipeline_c_producer_state);
ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
ElementQK acc_scale = (old_row_max == row_max_safe) ? 0.5f : 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
row_sum *= acc_scale;
// row_sum = sum(reg_S)
float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
@ -929,13 +934,16 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO_i); j += 2) {
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
if (scale != 1.0f) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO_i); j += 2) {
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
}
}
copy_out(i);
@ -1004,11 +1012,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
// e^(scale * (old_max - new_max)
float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
float scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O0));
bool warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
if (warp_do_correction) {
correction_rescale(scale, uint32_t(TmemAllocation::O0));
}
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
@ -1022,11 +1033,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O1));
warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
if (warp_do_correction) {
correction_rescale(scale, uint32_t(TmemAllocation::O1));
}
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
@ -1065,8 +1079,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
@ -1129,6 +1143,85 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
++pipeline_epi_producer_state;
}
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction_empty(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
float lse = -INFINITY;
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if (threadIdx.x % 128 == 0 && block0()) {
DSHOW(sO);
}
#if 1
using ElementOut = typename CollectiveEpilogue::ElementOut;
auto tiled_copy = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
sO.layout());
auto thr_copy = tiled_copy.get_slice(thread_idx);
auto tOgO = thr_copy.partition_D(sO);
auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
clear(tOrO);
copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
#endif
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
};
} // namespace cutlass::fmha::collective

View File

@ -86,10 +86,10 @@ struct Sm100FmhaGenMainloopWarpspecialized {
static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{});
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
using ClusterShape = Shape<_1, _1, _1>;
static const int Alignment = 128 / sizeof_bits_v<Element>;
@ -187,7 +187,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
SmemLayoutQ, SmemLayoutK, SmemLayoutV,
PipelineQ, PipelineKV, TileShape, Mask
>;
struct Arguments {
typename Load::Arguments load;
@ -534,14 +534,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
@ -622,7 +622,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
const int kReleasePipeCount = 10; // must be multiple of 2
order_s.wait();
CUTLASS_PRAGMA_UNROLL
@ -646,7 +646,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
@ -672,7 +672,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
pipeline_c.producer_acquire(pipeline_c_producer_state);
ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
ElementQK acc_scale = (old_row_max == row_max_safe) ? 0.5f : 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
row_sum *= acc_scale;
// row_sum = sum(reg_S)
float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
@ -700,7 +700,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
@ -781,7 +781,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
template<class Vector, class GTensor, class CTensor, class Shape, class Epilogue>
CUTLASS_DEVICE auto
correction_epilogue(
float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1,
float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1,
GTensor& gO, CTensor const& cO, Shape const& g_shape,
Epilogue const& epilogue) {
@ -794,13 +794,13 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOgO = mma.get_slice(0).partition_C(gO);
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
@ -812,7 +812,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0);
Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
@ -831,7 +831,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// loop:
// TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 128 / kCorrectionTileSize; i++) {
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0;
tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1;
@ -841,10 +841,10 @@ struct Sm100FmhaGenMainloopWarpspecialized {
Tensor tTMrO0 = make_tensor<ElementPV>(shape(tTMEM_LOADcO));
Tensor tTMrO1 = make_tensor<ElementPV>(shape(tTMEM_LOADcO));
copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0);
copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO0); j += 2) {
float2 in0 = make_float2(tTMrO0(j), tTMrO0(j+1));
@ -891,24 +891,24 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32;
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
@ -917,8 +917,8 @@ struct Sm100FmhaGenMainloopWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale);
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<get<2>(TileShape{}) / kCorrectionTileSize>{}));
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
@ -948,13 +948,16 @@ struct Sm100FmhaGenMainloopWarpspecialized {
}
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO_i); j += 2) {
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
if (scale != 1.0f) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO_i); j += 2) {
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
}
}
copy_out(i);
@ -981,7 +984,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
@ -1019,11 +1022,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
// e^(scale * (old_max - new_max)
float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
float scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O0));
bool warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
if (warp_do_correction) {
correction_rescale(scale, uint32_t(TmemAllocation::O0));
}
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
@ -1037,11 +1043,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O1));
warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
if (warp_do_correction) {
correction_rescale(scale, uint32_t(TmemAllocation::O1));
}
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;

View File

@ -170,8 +170,8 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
auto tSgQ = thr_mma_qk.partition_A(gQ);
auto tScQ = thr_mma_qk.partition_A(cQ);
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
auto tiled_copy_q = make_cotiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},

View File

@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
auto problem_shape_qk = problem_shape;
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dQ) = get<0>(dQ);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = problem_shape;
}
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length;
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);

View File

@ -0,0 +1,323 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element,
class StrideQ,
class StrideK,
class StrideV,
class CollectiveMmaQK,
class CollectiveMmaPV,
class SmemLayoutQ,
class SmemLayoutK,
class SmemLayoutV,
class TensorStorage,
class PipelineQ,
class PipelineKV,
class Mask,
class TileShape,
class OrderLoadEpilogue = cute::false_type
>
struct Sm100MlaFwdLoadTmaWarpspecialized {
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TileShapePV = typename CollectiveMmaPV::TileShape;
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
struct Arguments {
const Element* ptr_Q;
StrideQ dQ;
const Element* ptr_K;
StrideK dK;
const Element* ptr_V;
StrideV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
};
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace) {
auto ptr_Q = args.ptr_Q;
auto ptr_K = args.ptr_K;
auto ptr_V = args.ptr_V;
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;
}
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
problem_shape_qk,
typename CollectiveMmaQK::Arguments {
ptr_Q, dQ,
ptr_K, dK,
}, /*workspace=*/ nullptr);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
problem_shape_pv,
typename CollectiveMmaPV::Arguments {
ptr_K, dK, // never used, dummy
ptr_V, select<1,0,2>(dV),
}, /*workspace=*/ nullptr);
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
BlkCoord blk_coord_q = blk_coord_in;
BlkCoord blk_coord_kv = blk_coord_in;
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
using X = Underscore;
// this one is only executed by one thread, no need to elect_one
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
// two pipes: Q and KV
// from Memory (prod) to TensorCore (cons)
// compute gQ, sQ
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
int q_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto [tQgQ_qdl, tQsQ] = tma_partition(
params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
);
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
// compute gK, sK
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
int kv_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
auto [tKgK_kdl, tKsK] = tma_partition(
params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
);
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
// compute gV, sV
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
auto [tVgV_dkl, tVsV] = tma_partition(
params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
);
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
// blk_coord in decomposed in terms of TileShape, not TileShapeQK
// As such, it needs to be transformed as
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
uint32_t lane_predicate = cute::elect_one_sync();
// Q1
int q0_index = 2 * get<0>(blk_coord_q);
int q1_index = 2 * get<0>(blk_coord_q) + 1;
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// K1
int k_index = 0;
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
// Q2
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
// V1
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
k_index += 1;
// loop:
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// Ki
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
// prefetch vi
cute::prefetch(params.tma_load_v, tVgV(_, k_index));
}
++pipeline_kv_producer_state;
// Vi
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
// prefetch ki+1
if(mask_tile_count > 1) {
cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));
}
}
++pipeline_kv_producer_state;
k_index += 1;
}
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,250 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Support the producer to acquire specific bytes of data.
*/
#pragma once
#include "cutlass/pipeline/sm100_pipeline.hpp"
namespace cutlass {
using namespace cute;
template <
int Stages_,
class ClusterShape = Shape<int,int,_1>,
class AtomThrShape_MNK_ = Shape<_1,_1,_1>
>
class PipelineTmaAsyncMla {
public:
static constexpr uint32_t Stages = Stages_;
using AtomThrShape_MNK = AtomThrShape_MNK_;
private:
using Impl = PipelineTmaUmmaAsync<Stages_, ClusterShape, AtomThrShape_MNK_>;
public:
using FullBarrier = typename Impl::FullBarrier;
using EmptyBarrier = typename Impl::EmptyBarrier;
using ProducerBarrierType = typename Impl::ProducerBarrierType;
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
using PipelineState = typename Impl::PipelineState;
using SharedStorage = typename Impl::SharedStorage;
using ThreadCategory = typename Impl::ThreadCategory;
using Params = typename Impl::Params;
using McastDirection = McastDirection;
// Helper function to initialize barriers
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
auto atom_thr_shape = AtomThrShape_MNK{};
uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
(cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1;
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {
auto atom_thr_shape = AtomThrShape_MNK{};
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ?
cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas
cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
// Calculate consumer mask
if (params_.role == ThreadCategory::Consumer) {
auto cluster_layout = make_layout(cluster_shape);
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
// Calculate consumer mask
dim3 block_id_in_cluster = cute::block_id_in_cluster();
auto cluster_layout = make_layout(cluster_shape);
if (mcast_direction == McastDirection::kRow) {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
else {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
public:
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
, params_(params)
, empty_barrier_ptr_(&storage.empty_barrier_[0])
, full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape);
}
}
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
, params_(params)
, empty_barrier_ptr_(&storage.empty_barrier_[0])
, full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape, mcast_direction);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape, mcast_direction);
}
}
CUTLASS_DEVICE
void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.producer_acquire(state, barrier_token);
}
CUTLASS_DEVICE
void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {
detail::pipeline_check_is_producer(params_.role);
if (barrier_token != BarrierStatus::WaitDone) {
empty_barrier_ptr_[stage].wait(phase);
}
if (params_.is_leader) {
full_barrier_ptr_[stage].arrive_and_expect_tx(bytes);
}
#ifndef NDEBUG
if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) {
asm volatile ("brkpt;\n" ::);
}
// Most likely you have elected more than one leader
if (params_.is_leader && (threadIdx.x % 32 != 0)) {
asm volatile ("brkpt;\n" ::);
}
#endif
}
CUTLASS_DEVICE
void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token);
}
CUTLASS_DEVICE
ProducerBarrierType* producer_get_barrier(PipelineState state) {
return impl_.producer_get_barrier(state);
}
CUTLASS_DEVICE
void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.consumer_wait(state, barrier_token);
}
CUTLASS_DEVICE
void consumer_release(PipelineState state) {
consumer_release(state.index(), false);
}
private:
Impl impl_;
Params params_;
EmptyBarrier *empty_barrier_ptr_;
FullBarrier *full_barrier_ptr_;
uint16_t block_id_mask_ = 0;
static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
// Consumer signalling Producer of completion
// Ensures all blocks in the Same Row and Column get notifed.
CUTLASS_DEVICE
void consumer_release(uint32_t stage, uint32_t skip) {
detail::pipeline_check_is_consumer(params_.role);
uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
if (!skip) {
cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
}
}
else {
if (!skip) {
if constexpr (cute::is_static_v<ClusterShape> and size(ClusterShape{}) == 1) {
cutlass::arch::umma_arrive(smem_ptr);
}
else {
cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
}
}
}
}
};
}

View File

@ -39,6 +39,7 @@
#include "../device/fmha.hpp"
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
@ -50,35 +51,74 @@ namespace cutlass::fmha::device {
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<class Element, class ElementAccumulator, class TileShape, class Mask>
template<
class ProblemShape,
class Element,
class ElementAccumulator,
class TileShape,
bool IsMla,
class Mask
>
class Sm100FmhaBwd {
private:
template <typename T>
constexpr static auto to_bwd_shape(T shape) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(shape))::value;
auto HB = get<R-1>(shape);
auto rest = take<0,R-1>(shape);
return append(rest, make_shape(size<0>(HB), get<1>(HB)));
}
else {
return shape;
}
}
template <typename T>
constexpr static auto to_bwd_stride(T stride) {
if constexpr (IsMla) { // remove GQA mode
constexpr int R = decltype(rank(stride))::value;
auto HB = get<R-1>(stride);
auto rest = take<0,R-1>(stride);
if constexpr (is_same_v<remove_cv_t<decltype(get<0,0>(HB))>, _0>) {
return append(rest, make_stride(get<0,1>(HB), get<1>(HB)));
}
else {
return append(rest, make_stride(get<0,0>(HB), get<1>(HB)));
}
}
else {
return stride;
}
}
public:
/// Argument structure: User API
struct Arguments {
// Q K D HB
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
// Q K D D_VO HB
ProblemShape problem_shape;
const Element* ptr_Q;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_Q;
const Element* ptr_K;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_K;
const Element* ptr_V;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_V;
const Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
cute::tuple<cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dO;
Element* ptr_dQ;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dQ;
Element* ptr_dK;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dK;
Element* ptr_dV;
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dV;
ElementAccumulator softmax_scale;
@ -86,15 +126,27 @@ public:
};
using OperationSumOdO = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
>;
using OperationConvert = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
>;
using Operation = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
using OperationNormal= cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
ProblemShape, Element, ElementAccumulator, TileShape, Mask
>
>;
using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{}));
using OperationMla = cutlass::fmha::device::FMHA<
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask
>
>;
using Operation = std::conditional_t<IsMla, OperationMla, OperationNormal>;
using Kernel = typename Operation::Kernel;
struct Params {
@ -113,15 +165,16 @@ private:
ElementAccumulator* sum_odo = nullptr,
ElementAccumulator* scaled_lse = nullptr) {
using namespace cute;
auto [Q, K, D, HB] = args.problem_size;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
auto log2_e = log2f(expf(1.0f));
return typename OperationSumOdO::Arguments {
args.problem_size,
args.problem_shape,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
sum_odo, stride_sum_OdO,
@ -133,16 +186,17 @@ private:
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
using namespace cute;
auto [Q, K, D, HB] = args.problem_size;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = HB;
auto [H_R, H_K] = H;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
return typename OperationConvert::Arguments {
args.problem_size,
args.problem_shape,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV,
@ -152,21 +206,22 @@ private:
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_sum_OdO = {},
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_scaled_lse = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_dQ = {}) {
return typename Operation::Arguments{
args.problem_size,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
scaled_lse, stride_scaled_lse,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ,
to_bwd_shape(args.problem_shape),
{ args.ptr_Q, to_bwd_stride(args.stride_Q),
args.ptr_K, to_bwd_stride(args.stride_K),
args.ptr_V, to_bwd_stride(args.stride_V),
args.ptr_dO, to_bwd_stride(args.stride_dO),
scaled_lse, to_bwd_stride(stride_scaled_lse),
sum_OdO, to_bwd_stride(stride_sum_OdO),
dQ_acc, to_bwd_stride(stride_dQ),
args.softmax_scale },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
{ args.ptr_dK, to_bwd_stride(args.stride_dK),
args.ptr_dV, to_bwd_stride(args.stride_dV) },
args.hw_info
};
}
@ -199,10 +254,10 @@ public:
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
@ -219,10 +274,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
@ -248,10 +303,10 @@ public:
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [Q, K, D, HB] = args.problem_size;
auto [H, B] = HB;
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
auto [H, B] = product_each(HB);
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);

View File

@ -127,7 +127,11 @@ public:
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware;
if (args.is_fused_reduction && split_wave_aware > 1) {
args.split_kv = std::min(split_wave_aware, static_cast<int>(sm_count/2));
} else {
args.split_kv = split_wave_aware;
}
}
/// Determines whether the GEMM can execute the given problem.
@ -273,11 +277,33 @@ public:
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
auto [H, K, D, B] = params.fmha_params.problem_shape;
auto [D_latent, D_rope] = D;
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int);
uint8_t* ws = reinterpret_cast<uint8_t*>(params.fmha_params.epilogue.ptr_lse_exchange_buff);
result = cudaMemsetAsync(ws, 0, total_bytes, stream);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;;
}
}
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
@ -298,7 +324,7 @@ public:
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (params.reduction_params.split_kv > 1) {
if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);

View File

@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
// Swizzle Q tile and H tile to improve L2 cache hit rate,
// and launch the longest main loop first to keep most SMs busy.
struct CausalIndividualTileScheduler {
static constexpr int TileQ = 16;
static constexpr int TileH = 8;
static constexpr int TileSize = TileQ * TileH;
struct Params {
dim3 grid;
int tile_max_q;
FastDivmod divmod_tile_col;
FastDivmod divmod_tile_size;
FastDivmod divmod_tile_head;
};
bool valid_ = true;
Params params;
CUTLASS_DEVICE
CausalIndividualTileScheduler(Params const& params) : params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));
// gridDim.x must multiple of TileH
const int tile_col_count = grid.x / TileH;
const int tile_max_q = grid.y / TileQ * TileQ;
return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH};
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;
int tile_idx, tile_tail;
params.divmod_tile_size(tile_idx, tile_tail, block_idx);
int tile_row_idx, tile_col_idx;
params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx);
int row_offset_in_tail, col_offset_in_tail;
params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail);
const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;
const int col_idx = tile_col_idx * TileH + col_offset_in_tail;
// last q tile launch first
if(blockIdx.y >= params.tile_max_q) {
return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));
}
return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));
}
CUTLASS_DEVICE
CausalIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Launch order: H Q B
struct CausalPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_h;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
return Params {
num_blocks,
{ size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_h(block_decode, bidh, block_decode);
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
}
CUTLASS_DEVICE
CausalPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel

Some files were not shown because too many files have changed in this diff Show More