Compare commits: thakkarv/4 ... main (70 commits)
| SHA1 |
|---|
| 8afb19d904 |
| b2ca083d2b |
| b1d6e2c9b3 |
| e6e2cc29f5 |
| c6aeb9179c |
| 95a5ff14c0 |
| fb8b43ef05 |
| f874df19ac |
| 7a6d4ee099 |
| 2b8dff1f90 |
| fd0312ddf6 |
| 64579189ec |
| b234a8c024 |
| 74825181f2 |
| 8825e8be4f |
| 7817e47154 |
| 25ccb875b8 |
| 29c1ad704a |
| 57e3cfb47a |
| e7e0adddac |
| 6a35b4d22f |
| 56f0718a97 |
| 76c96b0be3 |
| d98e7bf7ce |
| b6ccf34aef |
| 2288c0c901 |
| b2dd65dc86 |
| 496654bf2c |
| 9ca7e877b2 |
| a49a78ffef |
| 11cad1f67b |
| 931359cec1 |
| 42e7c546c4 |
| ec18e8043b |
| 5b76420d6a |
| 19772cd63e |
| 052afcd314 |
| 86cf63e2d4 |
| a267d47f9b |
| 9e6ab77d27 |
| d0eada85a3 |
| 23139309e9 |
| 6dd13d4278 |
| 3b054767b3 |
| 6fb5e667c1 |
| 6c891db9f6 |
| da47886e34 |
| 26b7450023 |
| a39cf6b511 |
| f09045d660 |
| 84a27b3926 |
| e093b4f691 |
| 664c4f7b3e |
| 0e026982ce |
| 9a9a579714 |
| 51d730b8be |
| 6c0c8b7484 |
| e51efbfe18 |
| fd6cfe1ed0 |
| 9baa06dd57 |
| ebe98c549a |
| 9892624b66 |
| a1aaf2300a |
| b995f93317 |
| 889ff20648 |
| dc4817921e |
| 5c6bca0441 |
| c2ad7c5b20 |
| cc23f6d1e9 |
| 5a287538c2 |
.github/ISSUE_TEMPLATE/bug_report.md (vendored, 23 lines deleted)

@@ -1,23 +0,0 @@
---
name: Bug report
about: Create a bug report to help us improve CUTLASS
title: "[BUG]"
labels: "? - Needs Triage, bug"
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**Steps/Code to reproduce bug**
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Environment details (please complete the following information):**
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

**Additional context**
Add any other context about the problem here.
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, new file, 38 lines added)

@@ -0,0 +1,38 @@
name: Bug Report
description: Create a bug report to help us improve CUTLASS
title: "[BUG] "
labels: ["? - Needs Triage", "bug"]
assignees: []

body:
  - type: dropdown
    id: component
    attributes:
      label: Which component has the problem?
      options:
        - CuTe DSL
        - CUTLASS C++
    validations:
      required: true
  - type: textarea
    id: bug-report
    attributes:
      label: Bug Report
      description: Please fill out all sections below
      value: |
        **Describe the bug**
        A clear and concise description of what the bug is.

        **Steps/Code to reproduce bug**
        Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.

        **Expected behavior**
        A clear and concise description of what you expected to happen.

        **Environment details (please complete the following information):**
        - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

        **Additional context**
        Add any other context about the problem here.
    validations:
      required: true
.github/ISSUE_TEMPLATE/feature_request.md (vendored, 20 lines deleted)

@@ -1,20 +0,0 @@
---
name: Feature request
about: Suggest an idea for CUTLASS
title: "[FEA]"
labels: "? - Needs Triage, feature request"
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context, code examples, or references to existing implementations about the feature request here.
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, new file, 35 lines added)

@@ -0,0 +1,35 @@
name: Feature Request
description: Suggest an idea for CUTLASS
title: "[FEA] "
labels: ["? - Needs Triage", "feature request"]
assignees: []

body:
  - type: dropdown
    id: component
    attributes:
      label: Which component requires the feature?
      options:
        - CuTe DSL
        - CUTLASS C++
    validations:
      required: true
  - type: textarea
    id: feature-request
    attributes:
      label: Feature Request
      description: Please fill out all sections below
      value: |
        **Is your feature request related to a problem? Please describe.**
        A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]

        **Describe the solution you'd like**
        A clear and concise description of what you want to happen.

        **Describe alternatives you've considered**
        A clear and concise description of any alternative solutions or features you've considered.

        **Additional context**
        Add any other context, code examples, or references to existing implementations about the feature request here.
    validations:
      required: true
.github/workflows/auto-label-issues.yml (vendored, new file, 51 lines added)

@@ -0,0 +1,51 @@
name: Auto Label Issues

on:
  issues:
    types: [opened]

jobs:
  add-labels:
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Add component label
        uses: actions/github-script@v7
        with:
          script: |
            const issue = context.payload.issue;
            const body = issue.body || '';

            // Parse the issue body to find the component selection
            // GitHub renders dropdown selections as "### {label}\n\n{selection}"
            // Check for both bug report and feature request dropdown labels
            const bugComponentMatch = body.match(/### Which component has the problem\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);
            const featureComponentMatch = body.match(/### Which component requires the feature\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);

            const componentMatch = bugComponentMatch || featureComponentMatch;

            if (componentMatch) {
              const component = componentMatch[1].trim();
              let label = '';

              // Map component selections to labels
              switch(component) {
                case 'CuTe DSL':
                  label = 'CuTe DSL';
                  break;
                case 'CUTLASS C++':
                  label = 'CUTLASS C++';
                  break;
              }

              if (label) {
                await github.rest.issues.addLabels({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: issue.number,
                  labels: [label]
                });
                console.log(`Added label: ${label}`);
              }
            }
.github/workflows/blossom-ci.yml (vendored, 2 lines changed)

@@ -55,7 +55,7 @@ jobs:
    if: |
      (startsWith(github.event.comment.body, '/bot run') ||
      startsWith(github.event.comment.body, '/bot kill')) && contains(
      fromJson('["zekunf-nv"]'),
      fromJson('["nv-fastkernels-cicd", "zekunf-nv", "hwu36", "IonThruster", "thakkarV", "d-k-b", "mihir-awatramani", "fengxie", "vickiw973", "Junkai-Wu", "brandon-yujie-sun", "lijingticy22", "hongw-nv", "vikgupta-nv", "IwakuraRein", "depaulmillz", "jackkosaian", "itramble", "ccecka", "sxtyzhangzk", "hbarclay", "yzhaiustc", "x86vk", "sklevtsov-nvidia", "ANIKET-SHIVAM", "Shreya-gaur", "azhurkevich", "serifyesil", "richardmcai", "lsyyy666", "Ethan-Yan27", "XiaoSong9905", "shdetect", "keithzzzzz"]'),
      github.actor)
    steps:
      - name: Check if comment is issued by authorized person
CHANGELOG.md (277 lines changed)

@@ -2,50 +2,263 @@

# CUTLASS 4.x

## [4.3.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-10-20)

### CuTe DSL
* Debuggability improvements:
  - Supported source location tracking for DSL APIs
  - Supported dumping PTX and CUBIN code
* More examples and notebooks to get started with CuTe DSL:
  - [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
  - Improved performance of the elementwise kernel (https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/elementwise_apply.py):
    + Generalize code to handle a list of input tensors
    + Generalize TV layout computation to handle different data types
  - Demonstrate the new Pipeline APIs in [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py):
    + New Pipeline APIs `PipelineProducer` and `PipelineConsumer` to simplify code (no more explicit pipeline state management)
  - Separate epilogue code for non-TMA and TMA implementations
    + Note that the updates simplify the code, but existing APIs still work and are supported
  - [Basic Blackwell SM100 GEMM with decent performance](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py)
    + Simple tutorial achieves 84% SOL performance with MNK 8K
  - Reworked [elementwise add notebook](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb) with more details and a more detailed explanation of the TV layout
    + Updated implementation to handle general data types and multiple inputs
    + Updated explanation for TV layout in simpler language
    + Added visualization of TV Layout with 3rd party utils
  - [Benchmark and autotune demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb)
* More examples of authoring peak-performance kernels:
  - [Blackwell SM100 mixed-input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py)
  - [Blackwell SM100 persistent blockwise dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py)
  - [Blackwell SM100 persistent blockwise contiguous grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py)
  - [Blackwell SM100 persistent blockwise masked grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py)
  - [Blackwell SM100 fmha bwd](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha_bwd.py)
  - [Blackwell SM100 mla](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mla.py)
  - [Hopper SM90 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py)
  - [Blackwell GeForce batched dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py)
  - [Ampere HSTU Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/hstu_attention.py)
* API updates:
  - Please refer to the [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixes and improvements
  - Add mma_tiler_n=64 and mma_tiler_n=192 support in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py).
  - Fixed ``TensorSSA.reduce`` to support a static value as the initial value
  - Updated docstrings for the following APIs to be more concise and easier to understand:
    - ``make_layout_tv``
    - ``is_static``
    - ``PipelineAsync``
    - ``SmemAllocator``
  - Fixed documentation for ``pipeline``, ``utils`` and ``cute.math``

### CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
  - Add softmax skip correction.
  - Fix a shared memory allocation bug where the kernel needs to explicitly opt in to maximum dynamic shared memory once it exceeds 48KB.
  - Fix a dead hang issue caused by an early-returning warp.
* Add a Ragged Contiguous Grouped GEMM kernel in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/).
  - This kernel uses a TMA 3D load to load the weights matrix and uses the tensormap update method to load activations.
* Optimize grouped GEMM kernels by enabling async TMA descriptor updates.
* Support Blackwell SM100 convolution stream-K kernels.
  - Unit tests: [fprop_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [dgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [wgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu).
* Add profiler support for Blackwell SM100 and SM120 blockscaled sparse kernels.
* Fix some kernel issues:
  - Fix a race check issue in Blackwell SM103 kernels by adding the missing elect-one for prefetch barrier initialization.
  - Allow users to directly specify the number of stages for the Hopper SM90 mixed input GEMM.
  - Remove warnings caused by the CUDA vector type alignment setting in CUDA 13.
  - Remove the problematic `cutlass::int8_t` and replace it with `int8_t`.
* Fix some profiler issues:
  - Add some missing reference kernels.
  - Add calculation of scale factors A and B in the `bytes_with_problem_shape` function of the block scaled profiler.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 13.0U1.

## [4.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.1) (2025-09-22)

### CuTe DSL
* Bug fixes and improvements
  - Fixed an issue when running DSL code with cuda-python 13.0
  - Fixed an issue when running inductor with DSL code
  - Fixed an issue with unexpected logging when running DSL code in FlashInfer
  - Fixed the issue reported in https://github.com/NVIDIA/cutlass/issues/2647
  - Fixed an issue with conditional definition of variables outside of dynamic control flow

### CUTLASS C++
* Bypass EVT for nosmem blockwise kernels on Blackwell.
* Rename the cutlass/python/cutlass directory to cutlass/python/cutlass_cppgen.

## [4.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0) (2025-09-15)

### CuTe DSL
* More Python versions are now supported for both x86-64 and aarch64, including
  - Python 3.10, 3.11, 3.12, and 3.13
* Added a new example and updated notebook to get started with CuTe DSL
  - [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py)
  - Updates to the [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb)
    + Added a section introducing broadcast
* API updates
  - Please refer to the [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixes and improvements
  - Fixed ``cute.print_tensor`` for coordinate tensors
  - Fixed `cute.print` for tuples of layouts
  - Fixed an issue where a frozen object was not properly updated after being fully assigned in dynamic control flow
  - Fixed an issue where assigning a tuple/list element in dynamic control flow could cause a compilation failure
  - Improved the error message when the CUDA context is not initialized
  - Improved the docstrings of congruent and weakly_congruent

### CUTLASS C++
* Support for Blackwell SM103 kernels for B300 GPUs.
  - Collective mainloop code: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp)
  - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
  - Kernel code: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp).
* Set of examples that demonstrate the usage of the 3.x API for targeting the Blackwell SM103 architecture:
  - [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/).
  - [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm).
* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM
  - Unit test files with the prefix `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/).
* Support for Blackwell SM121 kernels for DGX Spark GPUs.
  - Shares most of its code with the Blackwell SM120 kernels.
* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario.
  - For details, please refer to the [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md); an example build configuration is sketched at the end of this section.
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
  - Add fused reduction kernel support for cutlass MLA.
  - Add softmax skip correction.
  - Support for GQA in the FMHA backward kernel.
  - Fix an issue where `get_unmasked_trip_count` may return a negative value.
  - Fix an issue where mbarriers are initialized with a zero arrival count.
  - Fix a corner case issue where the sequence length of q is not a multiple of tile_q.
  - Remove TMA padding for forward kernel inputs.
* Add Blackwell SM100 kernels for MoEs (focusing on low-latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allows only one problem dimension to vary across groups/experts, unlike general grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on the API is welcome.
* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell
  - On Blackwell SM120, a blockwise GEMM kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/).
  - On Hopper, add K-major scale factor support for SM90 blockwise kernels.
  - On Hopper, relax the restriction that the k dimension of the problem size has to be a multiple of the k dimension of the tile size.
  - On Hopper, the grouped version supports the case when k = 0.
* Support for Blackwell SM100 fp4 GEMV kernels.
  - Kernel code: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h).
  - Example code: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/)
* Support for Blackwell SM100 legacy mixed input GEMM kernels.
  - Collective mainloop code: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp).
  - Kernel code: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp).
  - Example code: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/).
* Support for the Blackwell SM100 cpasync kernel.
  - Collective mainloop code: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp).
  - Kernel code: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp).
* Support Blackwell SM120 mixed input blockscaled grouped GEMM.
* Instantiate more Blackwell kernels in the profiler.
  - Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations.
  - To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. The profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels (see the example configuration at the end of this section).
  - For details, please check the [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md).
* Fix some profiler issues:
  - Modify default cluster callback values to be non-zero to avoid profiler failure when these values are not set on the command line.
  - Fix some no-output and timeout issues.
  - Fix Pingpong Blockwise Hopper library generation.
* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110.
  - For CUDA toolkit versions < 13.0, SM101 is still used for Thor GPUs.
  - For CUDA toolkit versions >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid.
* Rename the legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to the legacy Python interface.
  - Restructured the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`.
  - Added a Blackwell SM100 EVT Emitter on the Python side and routed most emission through the Hopper SM90 Emitter.
  - Added some support for running SM100 kernels via the Python interface.
* CuTe changes:
  - Fix inaccurate GridDim calculation in the [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/).
  - Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support.
  - Fix the smallest MMA-N allowed for Blackwell fp8 and fp16 GEMM kernels.
  - Support fp16 accumulator for SM89 fp8 MMA.
  - Shorten the `nullspace` implementation.
  - Isolate and comment on `cosize` risky changes.
  - Important documentation correction: `E<0,1> == 1@0@1`.
* Fix some kernel issues:
  - Fix the Hopper SM90 grouped GEMM kernel to only use commit group and wait group instead of also waiting on mbarriers.
  - Fix a tiny bug when K is large for the Blackwell SM103 fp4 grouped GEMM kernel.
* Add the following unit tests:
  - [fp16 accumulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu)
  - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu)
  - [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu)
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 13.0U1.
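The bullets above on `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` and heuristics-based kernel filtering describe build-time options. A minimal sketch of how such configure steps might look; the architecture value, kernel filter string, and problem-list file name are illustrative assumptions, so see the profiler and heuristics docs linked above for the authoritative usage:

```bash
# Sketch: instantiate a filtered set of SM100 kernels at a higher instantiation level.
# CUTLASS_LIBRARY_KERNELS must be non-empty for the instantiation level to take effect.
cmake -B build -S . \
  -DCUTLASS_NVCC_ARCHS=100a \
  -DCUTLASS_LIBRARY_KERNELS=cutlass3x_sm100 \
  -DCUTLASS_LIBRARY_INSTANTIATION_LEVEL=max

# Sketch: heuristics-driven kernel selection via nvidia-matmul-heuristics.
# 'heuristics_problems.txt' is a hypothetical problem-shape list; the cache variable
# CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE appears in the CMakeLists.txt changes below.
cmake -B build -S . \
  -DCUTLASS_NVCC_ARCHS=100a \
  -DCUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE=heuristics_problems.txt
```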
## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)

### CuTe DSL
* Add aarch64 support; you can now pip install `nvidia-cutlass-dsl` on GB200 systems (see the command below)!
* More examples demonstrating how to use CuTe DSL to write peak-performance kernels
  - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
  - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
* API updates
  - Please refer to the [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
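As a quick reference, the package installs the same way on x86-64 and aarch64 hosts:

```bash
# Install the CuTe DSL wheel from PyPI (aarch64 wheels are available as of 4.1.0).
pip install nvidia-cutlass-dsl
```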
### CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
  - Add variable sequence length support for FMHA Backward kernel.
  - Add varlen test support to Backward runner.
  - Codes support empty batch sequences.
* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays.
* CuTe changes:
  - Rewrite ArithTuple and ScaledBasis for robustness and clarity.
  - Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX.
  - Factor out `print_latex` and friends and rewrite.
  - Factor out `print_svg` and friends and rewrite.
* Support Blackwell SM100 SIMT packed fp32x2 kernels.
* Support residual add for implicit gemm kernels.
* Various fixes for CUTLASS C++ Python interface's EVT tracer:
  - Add verifier for sm90 to report the invalid input.
  - When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges.
  - Register operations of tanh, sigmoid, exp, gelu to the python ast frontend.
  - Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback.
* Fix profiler bugs in exhaustive perf search.
  - Fix incorrect cluster shape output issue when doing exhaustive search.
  - Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders.
* Fix some profiler issues.
  - Complete the reference for Blackwell blockwise gemm kernels.
  - Fix incorrect regex logic for L1 test.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.9.

## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)

### CuTe DSL
* CuTe DSL, a Python DSL centered around CuTe's abstractions
  - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
  - [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
  - [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass)
  - [DSL quick start](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html)
  - [DSL Overview](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/overview.html)
* [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass/latest)
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
  - [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
  - [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
  - [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
  - [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
  - [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
  - [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
  - [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
  - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
  - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
  - [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
  - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/jit_argument.py)
  - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
  - Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
  - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details

### CUTLASS C++
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
  - 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note 101a is supported since CUTLASS 3.9
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
  - For example:
    `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
    `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
    + `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
    + `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
  - If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, filter kernels in the CUTLASS profiler with `--kernels`), please update your uses accordingly; this is a breaking change (see the example commands below).
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
  - Added non-power-of-two tile sizes.
  - Improved performance for K-major scale factors.
  - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions.
* Support LSE output in Blackwell FMHA Forward kernel in example 77.
  - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
  - Support LSE output in FMHA Forward kernel.
  - Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent.
  - Enhance testing of variable sequence length.
  - Disable B2B mode in MLA to simplify the sample.
  - Clarify that `fmha_gen` sample only supports head dim 128.
  - Fixes for split-kv output in MLA.
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
  - Enable runtime datatype for Blackwell grouped GEMM. Profiler support is also added.
  - Enable kernel parameter exploration for Blackwell grouped GEMM - raster_order, swizzle.
  - Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
  - Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
* Add dynamic and preferred cluster support for convolution kernels.
* Support for Blackwell SM120 blockwise dense gemm in cutlass core library, as well as cutlass profiler.
* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels.
* Fix profiler issues which cause no output or not supported error for some kernels.
* Optimization porting for BlockScaled collectives and kernel layers.
* New [Hopper FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* Optimizations for Blackwell SM100 and SM120 block scaled kernels.
* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler.
* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* CuTe changes:
  - Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda and introduces transform-tensors to avoid any extra register or load/store overhead in using bool-tensors.
  - New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy.
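The kernel-name change above affects any build or profiler invocation that filters by name. A minimal sketch using the new-style name quoted above; the build directory, problem sizes, and architecture value are illustrative assumptions:

```bash
# Build only kernels whose names match the new-style filter string.
cmake -B build -S . \
  -DCUTLASS_NVCC_ARCHS=90a \
  -DCUTLASS_LIBRARY_KERNELS=cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma

# Filter the same kernel in the CUTLASS profiler.
./build/tools/profiler/cutlass_profiler \
  --kernels=cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma \
  --m=4096 --n=4096 --k=4096
```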
@@ -104,7 +317,7 @@
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
* Support `void` as the D element in sm100 kernel epilogues.
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit versions 12.8U1.

@@ -123,7 +336,7 @@
- [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp).
- [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp).
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
- [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that

@@ -161,11 +374,11 @@
- A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
- A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html)
- A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
- Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture).
- Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.
- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html)
- A new [functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
- Updates to [compatibility](https://docs.nvidia.com/cutlass/latest/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/latest/overview.html#target-architecture).
- Updates to [profiler documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.

## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
- [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).

@@ -178,7 +391,7 @@
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
+ A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler).
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#cutlass-profiler).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
- Optimal code generation with CUDA toolkit versions 12.6.

@@ -192,12 +405,12 @@
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
- [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html).
- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html).
- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
- A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
- A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
- A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h)
- Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu).
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!

@@ -220,7 +433,7 @@
- Support for residual add (beta != 0) in convolution kernels.
- A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
- A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt).
- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html).
- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html).
- Better support for MSVC as a host compiler.
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.

@@ -228,7 +441,7 @@
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)

- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp)
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html).
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html).
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp).
+ Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
+ [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
CITATION.cff

@@ -57,6 +57,10 @@ authors:
    family-names: Blasig
    email: dblasig@nvidia.com
    affiliation: NVIDIA
  - given-names: Aditya
    family-names: Atluri
    email: aatluri@nvidia.com
    affiliation: NVIDIA
  - given-names: Fengqi
    family-names: Qiao
    email: fqiao@nvidia.com
CMakeLists.txt

@@ -73,6 +73,16 @@ endif()

include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

# nvcc supports response files with --options-file but some tools like clangd
# might choke on it. Thus provide a way to control the use of this feature.
set(CUTLASS_CUDA_USE_RESPONSE_FILE ON CACHE BOOL "Enable CUDA response files for includes, libraries, and objects")

if(NOT CUTLASS_CUDA_USE_RESPONSE_FILE)
  set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0)
  set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0)
  set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0)
endif()

if (CUDA_VERSION VERSION_LESS 11.3)
  message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 11.4 or higher, and strongly recommends CUDA 11.8 or higher.")
elseif (CUDA_VERSION VERSION_LESS 11.4)
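The new CUTLASS_CUDA_USE_RESPONSE_FILE option above is an ordinary CMake cache variable. A minimal sketch of turning it off for a clangd-friendly configure; the build directory and the extra export flag are illustrative assumptions:

```bash
# Disable nvcc response files so tools such as clangd can parse the generated
# compile commands; everything else about the build stays the same.
cmake -B build -S . \
  -DCUTLASS_CUDA_USE_RESPONSE_FILE=OFF \
  -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
```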
@@ -175,13 +185,25 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
endif()

if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a 121 121a)
  if (CUDA_VERSION VERSION_LESS 13.0)
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
  else()
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
  endif()
endif()

if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f 121f 103a 103f)
  if (CUDA_VERSION VERSION_LESS 13.0)
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
  else()
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110f)
  endif()
endif()

if (CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
endif()

set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")

@@ -288,17 +310,49 @@ if (KERNEL_FILTER_FILE)
  set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE)
endif()

if (CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
  get_filename_component(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" ABSOLUTE)
  set(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" CACHE STRING "HEURISTICS FILE FULL PATH" FORCE)
endif()

set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list")

if(KERNEL_FILTER_FILE)
  message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}")
endif()

if(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
  message(STATUS "Full path of heuristics problems file: ${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}")
  if(DEFINED CUTLASS_NVMMH_URL)
    message(STATUS "CUTLASS_NVVMH_URL is set. Fetching dependency")
    include(FetchContent)
    FetchContent_Declare(
      nvmmh
      URL ${CUTLASS_NVMMH_URL}
    )
    FetchContent_MakeAvailable(nvmmh)
    FetchContent_GetProperties(nvmmh SOURCE_DIR nvmmh_dir)
    set(CUTLASS_NVMMH_PATH "${nvmmh_dir}")
  endif()

  if(DEFINED CUTLASS_NVMMH_PATH)
    message(STATUS "CUTLASS_NVMMH_PATH is set. Using package at: ${CUTLASS_NVMMH_PATH}")

    set(CUTLASS_NVMMH_PY_DIR "${CUTLASS_NVMMH_PATH}/python/")
    set(ENV{CUTLASS_NVMMH_SO_PATH} "${CUTLASS_NVMMH_PATH}/lib/libnvMatmulHeuristics.so")
  endif()
endif()

set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.")
set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 and SM100 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")

if(CUTLASS_LIBRARY_INSTANTIATION_LEVEL OR CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
  message(STATUS "Enable extended SM90 WGMMA instruction shapes for instantiation levels")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
endif()

################################################################################

@@ -350,6 +404,10 @@ if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()

if (CUTLASS_NVCC_ARCHS MATCHES 110f)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()

set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")

#

@@ -428,8 +486,6 @@
#
###################################################################################################

# Warnings-as-error exceptions and warning suppressions for Clang builds
if (CUTLASS_CLANG_HOST_COMPILE)

@@ -704,9 +760,16 @@ target_include_directories(
  CUTLASS
  SYSTEM INTERFACE
  $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
  $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)

if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
  target_include_directories(
    CUTLASS
    SYSTEM INTERFACE
    $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
  )
endif()

install(
  DIRECTORY
  ${CUTLASS_INCLUDE_DIR}/

@@ -751,9 +814,9 @@ if(NOT WIN32)
# Add common library search paths so executables and libraries can load and run
# without LD_LIBRARY_PATH being set.
link_libraries(
  "-Wl,-rpath,'$ORIGIN'"
  "-Wl,-rpath,'$ORIGIN/../lib64'"
  "-Wl,-rpath,'$ORIGIN/../lib'"
  "-Wl,-rpath,'$$ORIGIN'"
  "-Wl,-rpath,'$$ORIGIN/../lib64'"
  "-Wl,-rpath,'$$ORIGIN/../lib'"
  "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'"
  "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'"
  ${CMAKE_DL_LIBS}

@@ -881,7 +944,7 @@ function(cutlass_add_executable_tests NAME TARGET)

  install(
    FILES ${__RESULT_CACHE_FILE}
    DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}/
    DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}
  )

endif()

@@ -1009,7 +1072,7 @@ function(cutlass_generate_profiler_tests NAME)

  install(
    FILES ${CUTLASS_PROFILER_REGRESSION_LIST_FILE}
    DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass/
    DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass
    RENAME profiler_regressions.csv
  )
153 README.md
@ -1,9 +1,9 @@
|
||||

|
||||
# Overview
|
||||
|
||||
# CUTLASS 4.0.0
|
||||
# CUTLASS 4.3.0
|
||||
|
||||
_CUTLASS 4.0.0 - May 2025_
|
||||
_CUTLASS 4.3.0 - Oct 2025_
|
||||
|
||||
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
|
||||
and related computations at all levels and scales within CUDA. It incorporates strategies for
|
||||
@ -40,47 +40,77 @@ designs, and bringing optimized solutions into production.
|
||||
CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
|
||||
|
||||
To get started quickly, please refer to:
|
||||
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
|
||||
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).
|
||||
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html).
|
||||
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html).
|
||||
|
||||
# What's New in CUTLASS 4.0
|
||||
# What's New in CUTLASS 4.3
|
||||
|
||||
## CuTe DSL
|
||||
* CuTe DSL, a Python DSL centered around CuTe's abstractions
|
||||
- [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
|
||||
- [DSL Quick Start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
|
||||
- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
|
||||
* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass)
|
||||
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
|
||||
- [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
|
||||
- [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
|
||||
- [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
|
||||
- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
|
||||
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
|
||||
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
|
||||
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
|
||||
* Debuggability improvements:
|
||||
- Supported source location tracking for DSL APIs
|
||||
- Supported dumping PTX and CUBIN code
|
||||
* More examples and notebooks to get started with CuTe DSL:
|
||||
- [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
|
||||
- Improved performance of elementwise kernel (https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/elementwise_apply.py):
|
||||
+ Generalize code to handle list of input tensors
|
||||
+ Generalize TV layout computation to handle different data types
|
||||
- Demonstrate the new Pipeline APIs in [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py):
|
||||
+ New Pipeline API `PipelineProducer` and `PipelineConsumer` to simplify code (no more explicit pipeline state management)
|
||||
- Separate epilogue code for non-TMA and TMA implementation
|
||||
+ Note that the updates simplify the code, but existing APIs still work and are supported
|
||||
- [Basic Blackwell SM100 GEMM with decent performance](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py)
|
||||
+ Simple tutorial achieves 84% SOL performance with MNK 8K
|
||||
- Reworked [elementwise add notebook](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb) with more details and detailed explanation about TV layout
|
||||
+ Updated implementation to handle general data type and multiple inputs
|
||||
+ Updated explanation for TV layout in simpler language
|
||||
+ Added visualization of TV Layout with 3rd party utils
|
||||
- [Benchmark and autotune demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb)
|
||||
* More examples of authoring peak-performance kernels:
|
||||
- [Blackwell SM100 mixed-input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py)
|
||||
- [Blackwell SM100 persistent blockwise dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py)
|
||||
- [Blackwell SM100 persistent blockwise contiguous grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py)
|
||||
- [Blackwell SM100 persistent blockwise masked grouped dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py)
|
||||
- [Blackwell SM100 fmha bwd](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha_bwd.py)
|
||||
- [Blackwell SM100 mla](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mla.py)
|
||||
- [Hopper SM90 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py)
|
||||
- [Blackwell GeForce batched dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py)
|
||||
- [Ampere HSTU Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/hstu_attention.py)
|
||||
* API updates:
|
||||
- Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
|
||||
* Bug fixes and improvements
|
||||
- Add mma_tiler_n=64 and mma_tiler_n=192 support in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py).
|
||||
- Fixed ``TensorSSA.reduce`` to support static value as initial value
|
||||
- Updated docstrings for the following APIs to be more concise and easier to understand:
|
||||
- ``make_layout_tv``
|
||||
- ``is_static``
|
||||
- ``PipelineAsync``
|
||||
- ``SmemAllocator``
|
||||
- Fixed documentation for ``pipeline``, ``utils`` and ``cute.math``
|
||||
|
||||
## CUTLASS C++
|
||||
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
|
||||
- 100f, 101f, and 120f were added to support Family Specific Architecture Features, which allow running the same binary on different chips belonging to the same family (e.g. sm100) without recompiling.
|
||||
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
|
||||
- For example:
|
||||
`(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
|
||||
`(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
|
||||
- If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS` or to filter kernels in the CUTLASS profiler with `--kernels`), please update your usage accordingly; this is a breaking change.
|
||||
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
|
||||
- Added non-power-of-two tile sizes.
|
||||
- Improved performance for K-major scale factors.
|
||||
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions.
|
||||
* Support LSE output in Blackwell FMHA Forward kernel.
|
||||
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
|
||||
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
* Optimal code generation with CUDA toolkit version 12.9.
|
||||
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
- Add softmax skip correction.
|
||||
- Fix a shared memory allocation bug where the kernel must explicitly opt in to maximum dynamic shared memory once it exceeds 48KB.
|
||||
- Fix a hang caused by an early-returning warp.
|
||||
* Add Ragged Contiguous Grouped gemm kernel in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/).
|
||||
- This kernel uses a TMA 3D load for the weight matrix and the tensormap update method to load activations.
|
||||
* Optimize grouped GEMM kernels by enabling async TMA descriptor updates.
|
||||
* Support Blackwell SM100 convolution stream-K kernel.
|
||||
- Unit tests: [fprop_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [dgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu), [wgrad_streamK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu).
|
||||
* Add profiler support for Blackwell SM100 and SM120 blockscaled sparse kernels.
|
||||
* Fix some kernel issues:
|
||||
- Fix a racecheck issue in Blackwell SM103 kernels by adding the missing elect-one for prefetch barrier initialization.
|
||||
- Allow users to directly specify the number of stages for Hopper SM90 mixed-input GEMM.
|
||||
- Remove warnings caused by CUDA vector type alignment settings in CUDA 13.
|
||||
- Remove problematic `cutlass::int8_t` and replace it with `int8_t`.
|
||||
* Fix some profiler issues:
|
||||
- Add some missing reference kernels.
|
||||
- Add calculation of scale factors A and B in the `bytes_with_problem_shape` function of the block-scaled profiler.
|
||||
|
||||
Note: CUTLASS 4.x builds are known to be broken on Windows platforms for all CUDA toolkits.
|
||||
The CUTLASS team is working on a fix.
|
||||
|
||||
**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.**
|
||||
**See the [CHANGELOG](https://docs.nvidia.com/cutlass/latest/CHANGELOG.html) for details of all past releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
@ -122,7 +152,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
|
||||
This greatly simplifies the design and improves code composability and readability.
|
||||
More documentation specific to CuTe can be found in its
|
||||
[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html).
|
||||
[dedicated documentation directory](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html).
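As a rough sketch of what this functional composition looks like in code (the layout shapes below are illustrative and not taken from the documentation):

```cpp
#include <cute/tensor.hpp>

// Minimal sketch (illustrative only): functionally compose two CuTe layouts.
// composition(A, B) builds the layout R with R(c) = A(B(c)), i.e. B selects
// coordinates and A maps them to offsets -- the "functional composition"
// referred to above.
int main() {
  using namespace cute;
  auto A = make_layout(make_shape(Int<4>{}, Int<8>{}));   // 4x8, column-major compact
  auto B = make_layout(make_shape(Int<2>{}, Int<2>{}),
                       make_stride(Int<2>{}, Int<8>{}));  // a strided 2x2 selector
  auto R = composition(A, B);                             // a 2x2 view through A, picked out by B
  print(R); print("\n");
  return 0;
}
```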
|
||||
|
||||
# Compatibility
|
||||
|
||||
@ -169,7 +199,10 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|
||||
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|
||||
|NVIDIA GeForce RTX 50x0 series |10.0|12.8|
|
||||
|NVIDIA B300 Tensor Core GPU |10.3|13.0|
|
||||
|NVIDIA DRIVE Thor |11.0|13.0|
|
||||
|NVIDIA GeForce RTX 50x0 series |12.0|12.8|
|
||||
|NVIDIA DGX Spark |12.1|13.0|
|
||||
|
||||
## Target Architecture
|
||||
|
||||
@ -201,11 +234,11 @@ cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
|
||||
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
|
||||
products has a different compute capability than the one underpinning
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs (SM120). As a result, kernels
|
||||
compiled for Blackwell SM100 architecture with arch conditional features
|
||||
(using `sm100a`) are not compatible with RTX 50 series GPUs.
|
||||
|
||||
Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html)
|
||||
Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html)
|
||||
for details on which kernels require which target architectures.
|
||||
|
||||
# Documentation
|
||||
@ -213,22 +246,22 @@ for details on which kernels require which target architectures.
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
|
||||
- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code
|
||||
- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
|
||||
- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
|
||||
- [Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
|
||||
- [Functionality](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html) - describes terms used in the code
|
||||
- [Programming Guidelines](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
|
||||
- [Dependent kernel launch](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
|
||||
kernels in the same stream, and how it is used in CUTLASS.
|
||||
|
||||
# Resources
|
||||
@ -248,7 +281,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
|
||||
paths.
|
||||
|
||||
CUTLASS unit tests, examples, and utilities can be built with CMake.
|
||||
The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
|
||||
The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html).
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
@ -293,7 +326,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
|
||||
and template concepts defined in the CUTLASS project.
|
||||
|
||||
A detailed explanation of the source code organization may be found in the
|
||||
[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below.
|
||||
[CUTLASS documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/code_organization.html), but several main components are summarized below.
|
||||
|
||||
## CUTLASS Template Library
|
||||
|
||||
@ -367,7 +400,7 @@ tools/
|
||||
The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html).
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
@ -583,9 +616,9 @@ reference_device: Passed
|
||||
|
||||
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
|
||||
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
|
||||
- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html)
|
||||
- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html)
|
||||
|
||||
|
||||
# About
|
||||
|
||||
@ -36,6 +36,7 @@ set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "
|
||||
|
||||
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
|
||||
function(cutlass_generate_kernel_filter_and_testlist_files)
|
||||
|
||||
set(options)
|
||||
@ -65,7 +66,12 @@ endfunction()
|
||||
|
||||
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
|
||||
|
||||
set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f)
|
||||
set(PROFILER_ARCH_LIST 100a 100f 103a 120a 120f 121a)
|
||||
if (CUDA_VERSION VERSION_LESS 13.0)
|
||||
list(APPEND PROFILER_ARCH_LIST 101a 101f)
|
||||
else()
|
||||
list(APPEND PROFILER_ARCH_LIST 110a 110f)
|
||||
endif()
|
||||
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
|
||||
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
|
||||
message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests")
|
||||
|
||||
@ -45,7 +45,7 @@
|
||||
cutlass::half_t
|
||||
|
||||
This is a numeric type implementing IEEE half-precision quantities. It is functional in host
|
||||
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables harware-accelerated
|
||||
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables hardware-accelerated
|
||||
numeric conversion on x86-64 CPUs that support F16C extensions. In device code, all available
|
||||
hardware is used to implement conversion and numeric operations.
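A minimal host-side sketch of such conversions (only `cutlass::half_t` itself comes from the text above; the values are illustrative):

```cpp
#include <iostream>
#include "cutlass/numeric_types.h"

// Illustrative host-side use of cutlass::half_t: construct from float, do
// arithmetic, and convert back. With CUTLASS_ENABLE_F16C defined, the
// float <-> half conversions may use the x86-64 F16C instructions.
int main() {
  cutlass::half_t a(1.5f);
  cutlass::half_t b(0.25f);
  cutlass::half_t c = a * b + cutlass::half_t(2.0f);
  std::cout << "c = " << float(c) << std::endl;   // prints 2.375
  return 0;
}
```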
|
||||
|
||||
|
||||
@ -243,10 +243,11 @@ cudaError_t run_batched_gemm(bool use_array) {
|
||||
const char* gemm_desc = use_array ? "array" : "strided batched";
|
||||
std::cout << "Running " << gemm_desc << " gemm" << std::endl;
|
||||
|
||||
// Arbitrary problem size
|
||||
// Arbitrary matrix shape
|
||||
int const m = 520;
|
||||
int const n = 219;
|
||||
int const k = 129;
|
||||
|
||||
int const batch_count = 17;
|
||||
|
||||
// A, B are non-transpose, column major
|
||||
|
||||
@ -64,7 +64,7 @@ ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutla
|
||||
ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not
|
||||
enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do
|
||||
that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB
|
||||
to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C
|
||||
to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C
|
||||
which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the
|
||||
data type of output ElementOutput (int32_t), the number of elements per vector memory access (16),
|
||||
data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X +
|
||||
|
||||
@ -64,7 +64,7 @@ ElementComputeEpilogue (int32_t), ElementInputA (int8_t), ElementInputB (int8_t)
|
||||
(int32_t). Communicating just the data type is not enough. As the data is laid out linearly in
|
||||
memory, we have to convey the layout of matrices. We do that by initializing template variable
|
||||
LayoutInputA to column major cutlass variable, LayoutInputB to row major and LayoutOutput to row
|
||||
major. Next, we setup rules to comptue alpha * X + beta * C which is called epilogue of the kernel.
|
||||
major. Next, we setup rules to compute alpha * X + beta * C which is called epilogue of the kernel.
|
||||
We initialize template variable EpilogueOp, which takes the data type of output ElementOutput
|
||||
(int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t)
|
||||
and data type of computation of linear combination (alpha * X + beta * C).
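Such an `EpilogueOp` is typically spelled with `cutlass::epilogue::thread::LinearCombination`; a generic sketch follows (the element types and access width are placeholders, not this example's exact configuration):

```cpp
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// Sketch: an epilogue functor computing D = alpha * accum + beta * C.
// Template arguments, in order: output element type, number of elements per
// vectorized memory access, accumulator type, and the type used to compute
// the linear combination. The aliases below are illustrative placeholders.
using ElementOutput          = float;
using ElementAccumulator     = float;
using ElementComputeEpilogue = float;

using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // elements per vector access
    ElementAccumulator,
    ElementComputeEpilogue>;
```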
|
||||
|
||||
@ -66,7 +66,7 @@ ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB
|
||||
ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out
|
||||
linearly in memory, we have to convey the layout of tensors. We do that by initializing template
|
||||
variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup
|
||||
rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template
|
||||
rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template
|
||||
variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of
|
||||
elements per vector memory access (32), data type of accumulator (int32_t) and data type of
|
||||
computation of linear combination (alpha * X + beta * C).
|
||||
|
||||
@ -177,7 +177,7 @@ public:
|
||||
if(args.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
|
||||
// The user needs to call a reduction operator to optain the final output tensor
|
||||
// The user needs to call a reduction operator to obtain the final output tensor
|
||||
workspace_bytes =
|
||||
sizeof(ElementAccumulator) *
|
||||
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *
|
||||
|
||||
@ -153,7 +153,7 @@ struct Options {
|
||||
|
||||
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
|
||||
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
|
||||
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
|
||||
<< " run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n"
|
||||
<< " back-to-back GEMMs are subject to.s"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
|
||||
@ -248,7 +248,7 @@ struct B2bGemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// Epilogue params remain constant across all problems in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params epilogue0;
|
||||
typename OutputOp1::Params epilogue1;
|
||||
@ -402,7 +402,7 @@ struct B2bGemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// Epilogue params remain constant across all problems in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
@ -434,7 +434,7 @@ struct B2bGemm {
|
||||
// Only row-major outputs are currently supported, so no transpose is performed
|
||||
}
|
||||
|
||||
/// Returns non-grouped paramaters to be used as input to the kernel-level
|
||||
/// Returns non-grouped parameters to be used as input to the kernel-level
|
||||
/// operator for the problem indicated by problem_visitor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params to_single_params(const ProblemVisitor& problem_visitor) const {
|
||||
|
||||
@ -560,7 +560,7 @@ struct DefaultB2bConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
// multistage pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
|
||||
@ -606,7 +606,7 @@ struct DefaultB2bConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
// multistage pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
|
||||
@ -277,7 +277,7 @@ public:
|
||||
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
|
||||
@ -298,7 +298,7 @@ public:
|
||||
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
|
||||
@ -93,7 +93,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
|
||||
@ -203,7 +203,7 @@ requires any memory for scratch space.
|
||||
If yes, we reserve scratch space and pass it along
|
||||
with other arguments to initialize the CUTLASS kernel.
|
||||
|
||||
After lauching the CUTLASS kernel, this example runs
|
||||
After launching the CUTLASS kernel, this example runs
|
||||
a reference convolution kernel (from CUTLASS utilities)
|
||||
to check correctness.
|
||||
*/
|
||||
|
||||
@ -144,7 +144,7 @@ int run() {
|
||||
// Construct Gemm ProblemSize with user defined output size
|
||||
cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024};
|
||||
|
||||
// Stride factor shows the distance between two elements in the differnet dimensions. The
|
||||
// Stride factor shows the distance between two elements in the different dimensions. The
|
||||
// first data is the logical distance between two rows, the second is between two columns.
|
||||
// CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory<Layout>::layout_factory
|
||||
// to help to convert stride_factor to the two strides.
|
||||
|
||||
@ -55,7 +55,7 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Define the overal warp-level problem shape
|
||||
// Define the overall warp-level problem shape
|
||||
int const kM = 27;
|
||||
int const kN = 31;
|
||||
int const kK = 17;
|
||||
|
||||
@ -59,7 +59,7 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Define the overal warp-level problem shape
|
||||
// Define the overall warp-level problem shape
|
||||
int const kM = 14;
|
||||
int const kN = 27;
|
||||
int const kK = 17;
|
||||
|
||||
@ -659,7 +659,7 @@ struct Testbed {
|
||||
}
|
||||
|
||||
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
|
||||
int64_t bytes = cutlass::bits_to_bytes(
|
||||
int64_t bytes = cutlass::bits_to_bytes<int64_t>(
|
||||
(cutlass::sizeof_bits<ElementD>::value * 2 + cutlass::sizeof_bits<ElementSoftmax>::value) *
|
||||
options.problem_size.m() * options.problem_size.n());
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
// This example fuses gather before GEMM and scatter after GEMM into the same
|
||||
// GEMM kernel. Gather and scatter operation is controled by an index vector
|
||||
// GEMM kernel. Gather and scatter operation is controlled by an index vector
|
||||
// to select rows or columns from A, B, C or D matrices.
|
||||
//
|
||||
// Suppose, all matrices are column major. The pseudo code of the fused kernel
|
||||
|
||||
@ -87,7 +87,7 @@ public:
|
||||
using ElementLayernormCompute = ElementLayernormCompute_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
|
||||
// Pre-processing has ensured the layout equivelent to RowMajor
|
||||
// Pre-processing has ensured the layout equivalent to RowMajor
|
||||
using Layout = cutlass::layout::RowMajor;
|
||||
|
||||
using TensorVariance = TensorRef<ElementVariance, Layout>;
|
||||
|
||||
@ -33,8 +33,8 @@
|
||||
computing reference permutations of 4/5D tensors when source data is column-major.
|
||||
*/
|
||||
#pragma once
|
||||
#include <cuda/std/cassert>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/coord.h"
|
||||
|
||||
@ -87,7 +87,7 @@ parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices
|
||||
"TensorNHWC", "TensorNC32HW32"],
|
||||
help="Memory layout of input tensor A")
|
||||
parser.add_argument('-aa', '--alignment_a', default=1,
|
||||
type=int, help="Memory alignement of input tensor A")
|
||||
type=int, help="Memory alignment of input tensor A")
|
||||
# B
|
||||
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
|
||||
"TensorNHWC", "TensorC32RSK32"],
|
||||
|
||||
@ -86,7 +86,7 @@ parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
|
||||
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
|
||||
help="Memory layout of input tensor A")
|
||||
parser.add_argument('-aa', '--alignment_a', default=1,
|
||||
type=int, help="Memory alignement of input tensor A")
|
||||
type=int, help="Memory alignment of input tensor A")
|
||||
# B
|
||||
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
|
||||
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
|
||||
|
||||
@ -40,14 +40,12 @@
|
||||
Note that in general the fragment passed to the OutputOp could
|
||||
span multiple rows but it does not happen with the configurations we have
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
|
||||
@ -42,12 +42,10 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
|
||||
@ -55,7 +55,7 @@
|
||||
```
|
||||
|
||||
In practice, and for numerical stability reasons,
|
||||
we also substract the maximum so far (`mi`) before doing
|
||||
we also subtract the maximum so far (`mi`) before doing
|
||||
the exponential. When we encounter new keys, the maximum
|
||||
used to compute O so far (`m_prime`) can differ from the
|
||||
current maximum, so we update O before accumulating with
|
||||
|
||||
@ -55,7 +55,7 @@
|
||||
```
|
||||
|
||||
In practice, and for numerical stability reasons,
|
||||
we also substract the maximum so far (`mi`) before doing
|
||||
we also subtract the maximum so far (`mi`) before doing
|
||||
the exponential. When we encounter new keys, the maximum
|
||||
used to compute O so far (`m_prime`) can differ from the
|
||||
current maximum, so we update O before accumulating with
|
||||
|
||||
@ -31,7 +31,7 @@
|
||||
|
||||
/*! \file
|
||||
\brief Cutlass provides helper template functions to figure out the right
|
||||
datastructures to instanciate to run a GEMM with various parameters (see
|
||||
datastructures to instantiate to run a GEMM with various parameters (see
|
||||
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
|
||||
instantiation priority rules, it will only create an MmaMultiStage with
|
||||
kStages=3 (otherwise creates an MmaPipelined - which is not compatible with
|
||||
@ -83,7 +83,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
typename Enable_ = void>
|
||||
struct FindDefaultMma {
|
||||
|
||||
@ -522,7 +522,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
|
||||
// For API compatibility with MmaMultistageFromSharedMemory
|
||||
// but not supported as it worsens perf: older gpus < sm80 don't
|
||||
// support async tranfers and have to waste registers
|
||||
// support async transfers and have to waste registers
|
||||
CUTLASS_DEVICE
|
||||
void set_prologue_done(bool value) {}
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Instanciates the right WarpIterator to read from shared memory
|
||||
\brief Instantiates the right WarpIterator to read from shared memory
|
||||
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
|
||||
data dumped with `B2bGemm::accumToSmem`.
|
||||
*/
|
||||
|
||||
@ -86,7 +86,7 @@ namespace threadblock {
|
||||
/// To be efficient, this assumes the iterator will be dereferenced and advanced
|
||||
/// at least once outside any looping structure to minimize integer arithmetic.
|
||||
///
|
||||
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to
|
||||
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
|
||||
/// dereferencing the iterator.
|
||||
///
|
||||
///
|
||||
|
||||
@ -49,7 +49,7 @@
|
||||
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
|
||||
for this example:
|
||||
a_rows - Rows in the sparse matrix.
|
||||
a_cols - Colums in the sparse matrix.
|
||||
a_cols - Columns in the sparse matrix.
|
||||
a_ell_blocksize - Size of the ELL-Blocks.
|
||||
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
|
||||
tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns)
|
||||
|
||||
@ -38,10 +38,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
|
||||
@ -153,7 +153,7 @@ class gen_device:
|
||||
|
||||
warp_M_tile = 32
|
||||
|
||||
# Determine maxmimum N_tile
|
||||
# Determine maximum N_tile
|
||||
Max_Ntile = 0
|
||||
for layer in self.fuse_gemm_info:
|
||||
n_tile = layer['mnk'][1]
|
||||
|
||||
@ -76,9 +76,9 @@ class gen_verify:
|
||||
)
|
||||
|
||||
|
||||
def get_params(self, declartion = True):
|
||||
def get_params(self, declaration = True):
|
||||
code = ""
|
||||
if declartion:
|
||||
if declaration:
|
||||
for param in self.params:
|
||||
code += param[0] + " " + param[1] + ";\n"
|
||||
|
||||
|
||||
@ -64,8 +64,8 @@ def write_2_headfile(filename, file_dir, string):
|
||||
with open(file_dir + filename, 'w') as f:
|
||||
f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string)
|
||||
|
||||
def var_idx(varaiable, index):
|
||||
return varaiable + str(index)
|
||||
def var_idx(variable, index):
|
||||
return variable + str(index)
|
||||
|
||||
|
||||
def list_2_string(input_list, ):
|
||||
|
||||
@ -37,12 +37,10 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
@ -78,7 +78,7 @@
|
||||
a single default value.
|
||||
|
||||
CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of
|
||||
the collective is specified via the schedule tags that corresond to the underlying collective's
|
||||
the collective is specified via the schedule tags that correspond to the underlying collective's
|
||||
dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto`
|
||||
are special cases of these schedules that allow the builder to also decide the dispatch policy for you,
|
||||
therefore letting the builder pick the collective specialization.
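In code, letting the builder choose the schedule looks roughly like the sketch below (element types, layouts, alignments, and tile shapes are placeholders, not this example's configuration):

```cpp
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

// Sketch: a CUTLASS 3.x collective mainloop built with KernelScheduleAuto, so
// the builder also decides the dispatch policy for the given architecture and
// tile configuration. All aliases below are illustrative placeholders.
using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_64>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::half_t, cutlass::layout::RowMajor,    8,   // A: type, layout, alignment
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // B: type, layout, alignment
    float,                                              // accumulator type
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto       // builder picks the dispatch policy
  >::CollectiveOp;
```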
|
||||
|
||||
@ -425,7 +425,7 @@ int main(int argc, char const **args) {
|
||||
// Pipeline Depth to be used i.e number of A, B buffers in shared memory
|
||||
constexpr int PipelineStages = 8;
|
||||
|
||||
// Let's choose a Warp-Specialized Mainloop implemention which uses TMA
|
||||
// Let's choose a Warp-Specialized Mainloop implementation which uses TMA
|
||||
// Note : This requires / assumes the tensors to be 16B aligned
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>;
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
\brief Example of a Hopper gather+GEMM+scatter kernel fusion.
|
||||
|
||||
This example fuses gather before GEMM and scatter after GEMM into the same
|
||||
GEMM kernel. Gather and scatter operation is controled by an index vector
|
||||
GEMM kernel. Gather and scatter operation is controlled by an index vector
|
||||
to select rows or columns from A, B, C or D matrices.
|
||||
|
||||
Gather/scatter operations are always performed along a strided dimension
|
||||
|
||||
@ -65,7 +65,7 @@
|
||||
The approach relies on two things:
|
||||
- The ability of CUTLASS 3 to naturally perform general tensor contractions (GETT) owing to the
|
||||
flexibility of CuTe's hierarchical layouts (see example 51_hopper_gett for more details).
|
||||
- The harware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
|
||||
- The hardware capabilities of Hopper TMA units that allow for loading multidimensional tensors with
|
||||
(almost) arbitrary strides, which can be used to represent a permuted view of the data.
|
||||
|
||||
In this example we reuse the permutation classes of examples 39_gemm_permute as operation tags.
|
||||
|
||||
@ -188,7 +188,7 @@ Running this example on an RTX 3080Ti prints the following performance numbers (
|
||||
|
||||
```
|
||||
$> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check
|
||||
Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.
|
||||
Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.
|
||||
|
||||
Allocating tensors ... done.
|
||||
Initializing data ... done.
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propogation kernel
|
||||
\brief Example demonstrating CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel
|
||||
capable of operating on both affine and gather/scatter tensors.
|
||||
|
||||
This example demonstrates a few super cool features of CUTLASS and CuTe. It shows off
|
||||
@ -284,7 +284,7 @@ int ampere_gather_scatter_conv_fprop(
|
||||
int
|
||||
main(int argc, char const** argv) {
|
||||
cutlass::CommandLine cmd(argc, argv);
|
||||
std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n";
|
||||
std::cout << "Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.\n\n";
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
std::cout
|
||||
<< "Options:\n"
|
||||
|
||||
@ -291,7 +291,7 @@ struct Options {
|
||||
// Post-process the problem sizes
|
||||
bin_problems();
|
||||
|
||||
// Initalize alpha array
|
||||
// Initialize alpha array
|
||||
randomize_alpha_ptr_array(cmd);
|
||||
}
|
||||
|
||||
|
||||
@ -26,7 +26,9 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 90a)
|
||||
cutlass_example_add_executable(
|
||||
65_distributed_gemm
|
||||
65_distributed_gemm.cu
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -129,7 +129,7 @@ using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_confi
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
@ -358,7 +358,7 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
// Layout SFA and SFB represent logically broadcasting data in CuTe.
|
||||
// E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGranularityK, K / ScaleGranularityK))
|
||||
// and strides ((0, 1), (0, M / ScaleGranularityM)), then each collection of ScaleGranularityM x ScaleGranularityK
|
||||
// indecies in the tensor map to the same offset.
|
||||
// indices in the tensor map to the same offset.
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
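A stand-alone sketch of the zero-stride broadcasting described above (the concrete sizes are illustrative, not the example's configuration):

```cpp
#include <cute/tensor.hpp>

// Sketch: a CuTe layout whose zero strides broadcast one scale value over a
// whole ScaleGranularityM x ScaleGranularityK block. With M = 256, K = 256 and
// a granularity of 128, every coordinate inside a 128x128 block maps to the
// same offset, so only (M/128) x (K/128) distinct scales are stored.
int main() {
  using namespace cute;
  auto sfa = make_layout(
      make_shape (make_shape(Int<128>{}, Int<2>{}), make_shape(Int<128>{}, Int<2>{})),
      make_stride(make_stride(Int<0>{},  Int<1>{}), make_stride(Int<0>{},  Int<2>{})));
  print(sfa); print("\n");   // ((_128,_2),(_128,_2)):((_0,_1),(_0,_2))
  return 0;
}
```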
|
||||
|
||||
@ -132,12 +132,12 @@ constexpr int ScaleGranularityK = 128;
|
||||
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
|
||||
@ -145,7 +145,7 @@ using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularity
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
|
||||
@ -402,12 +402,37 @@ void initialize(const OptionType &options) {
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
// If the current group's matrix has size 0, set the pointer to nullptr
|
||||
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
|
||||
ptr_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
|
||||
ptr_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
|
||||
ptr_C_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
|
||||
ptr_D_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
|
||||
ptr_blockscale_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
|
||||
ptr_blockscale_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
}
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
@ -546,10 +571,10 @@ bool verify(const OptionType &options) {
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
@ -598,11 +623,7 @@ bool verify(const OptionType &options) {
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // Aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
@ -639,6 +660,24 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::string raster = "Heuristic";
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
@ -671,8 +710,7 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
@@ -686,25 +724,6 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

std::string raster = "Heuristic";

if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M";
}

std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
fflush(stdout);

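For context on the numbers reported above, here is a back-of-the-envelope sketch of how the average runtime and GFLOPS relate for a grouped GEMM. This is an assumption-level model, not the example's options.gflops() implementation:

#include <tuple>
#include <vector>

// A grouped GEMM performs roughly 2*M*N*K floating-point operations per group,
// so GFLOPS = total_flops / (avg_runtime_seconds * 1e9).
double grouped_gemm_gflops(const std::vector<std::tuple<int, int, int>>& problem_sizes,
                           double avg_runtime_ms) {
  double flops = 0.0;
  for (const auto& [m, n, k] : problem_sizes) {
    flops += 2.0 * double(m) * double(n) * double(k);  // multiply-adds count as two ops
  }
  double seconds = avg_runtime_ms / 1000.0;
  return flops / seconds / 1.0e9;
}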
@@ -132,8 +132,7 @@ using ElementCompute = float; // E

using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag

using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster

static constexpr int ScaleGranularityM = 1;
@@ -142,13 +141,13 @@ static constexpr int ScaleGranularityK = 128;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;

using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;

using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand


using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
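A small sketch of the blockwise-scaling bookkeeping these aliases encode, written with plain integers instead of CuTe shapes. The TileM/TileN values mirror the TileShape in the hunk; ScaleGranularityN is an assumed value here, since the hunk only shows how it is used:

// One scale factor covers a ScaleGranularityM x ScaleGranularityN block of the
// output tile, so the per-CTA scale counts are just the quotients below.
constexpr int TileM = 128, TileN = 128;
constexpr int ScaleGranularityM = 1;    // one scale per output row within a tile
constexpr int ScaleGranularityN = 128;  // assumed: one scale per full tile width
constexpr int ScaleMsPerTile = TileM / ScaleGranularityM;  // 128 scale rows per CTA tile
constexpr int ScaleNsPerTile = TileN / ScaleGranularityN;  // 1 scale column per CTA tile
static_assert(TileM % ScaleGranularityM == 0 && TileN % ScaleGranularityN == 0,
              "scale granularity must evenly divide the CTA tile");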
@ -407,12 +406,37 @@ void initialize(const OptionType &options) {
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
// If the current group's matrix has size 0, set the pointer to nullptr
|
||||
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
|
||||
ptr_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
|
||||
ptr_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
|
||||
ptr_C_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
|
||||
ptr_D_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
|
||||
ptr_blockscale_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
|
||||
ptr_blockscale_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
}
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
@ -551,10 +575,10 @@ bool verify(const OptionType &options) {
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
@ -637,10 +661,27 @@ bool verify(const OptionType &options) {
|
||||
template <typename OptionType>
|
||||
int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::string raster = "Heuristic";
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
@ -695,27 +736,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
ScaleMsPerTile,
|
||||
ScaleNsPerTile>(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
std::cout << " GBPS: " << result.gbps << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
return 0;
|
||||
@ -766,8 +790,8 @@ int main(int argc, char const **args) {
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running tests with host problem shapes:" << std::endl;
|
||||
run(options, true);
|
||||
|
||||
std::cout << "Running tests without host problem shapes:" << std::endl;
|
||||
run(options, false);
|
||||
|
||||
|
||||
@ -44,6 +44,9 @@ set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0)
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_K_16B_ALIGNED --m=256 --n=512 --k=960 --groups=10 --iterations=0)
|
||||
set(TEST_K_16B_ALIGNED_LARGE_GROUP --m=256 --n=512 --k=960 --groups=512 --iterations=0)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
|
||||
@ -58,6 +61,8 @@ cutlass_example_add_executable(
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_K_16B_ALIGNED
|
||||
TEST_K_16B_ALIGNED_LARGE_GROUP
|
||||
)
|
||||
|
||||
# MSVC will fail to compile this example with the following error:
|
||||
|
||||
@@ -111,14 +111,14 @@ struct Options {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
if (m < 0) {
m = m_alignment * (rand() % (64 * alignment / m_alignment));
}
if (n < 1) {
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
if (n < 0) {
n = n_alignment * (rand() % (64 * alignment / n_alignment));
}
if (k < 1) {
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
if (k < 0) {
k = k_alignment * (rand() % (32 * alignment / k_alignment));
}
problem_sizes_after_alignment_host.push_back({m, n, k});
problem_sizes_host.push_back({m, n, k});

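A standalone sketch of the randomized problem-size pattern above: dropping the "+ 1" and relaxing the guard to "< 0" lets a randomly drawn extent come out as zero, which is exactly the case the nullptr handling for empty groups is meant to cover. The helper name and parameters are illustrative:

#include <cstdlib>

// Keep an explicitly requested non-negative extent; otherwise draw a random
// multiple of 'alignment' in [0, alignment * max_tiles), zero included.
int random_aligned_extent(int requested, int alignment, int max_tiles) {
  if (requested >= 0) {
    return requested;                            // honor the command-line value
  }
  return alignment * (std::rand() % max_tiles);  // may be zero: an empty group
}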
@ -454,11 +454,12 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -640,11 +640,11 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
|
||||
@ -33,7 +33,7 @@ set(TEST_SWIZZLE_2 --swizzle=2)
|
||||
set(TEST_SWIZZLE_5 --swizzle=5)
|
||||
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp16_gemm
|
||||
70_blackwell_fp16_gemm.cu
|
||||
|
||||
@ -449,9 +449,9 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
|
||||
cutlass_example_add_executable(
|
||||
71_blackwell_gemm_with_collective_builder
|
||||
71_blackwell_gemm_with_collective_builder.cu
|
||||
|
||||
@ -40,10 +40,10 @@
|
||||
|
||||
Similar to 70_blackwell_gemm, this kernel leverages:
|
||||
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
|
||||
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
@ -116,10 +116,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
@ -190,13 +190,7 @@ cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_referen
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -329,7 +323,7 @@ bool initialize_block(
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -413,7 +407,7 @@ bool verify(const Options &options) {
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
@ -514,13 +508,13 @@ int main(int argc, char const **args) {
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
|
||||
@ -39,10 +39,10 @@
|
||||
1. Blockscaled tcgen05.mma instructions.
|
||||
|
||||
2. Per-SM memory called Tensor Memory (TMEM)
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
@ -129,13 +129,13 @@ constexpr int OutputSFVectorSize = InputSFVectorSize;
|
||||
// With BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, LayoutSFDTag,
|
||||
ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
@ -219,13 +219,7 @@ cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_N
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -358,7 +352,7 @@ bool initialize_block(
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -456,7 +450,7 @@ bool verify(const Options &options) {
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
|
||||
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
|
||||
@ -569,11 +563,11 @@ int main(int argc, char const **args) {
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -41,10 +41,10 @@
|
||||
1. Blockscaled tcgen05.mma instructions.
|
||||
|
||||
2. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
@ -117,10 +117,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
@ -191,13 +191,7 @@ cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_referen
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -330,7 +324,7 @@ bool initialize_block(
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -414,7 +408,7 @@ bool verify(const Options &options) {
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
@ -515,14 +509,14 @@ int main(int argc, char const **args) {
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
|
||||
cutlass_example_add_executable(
|
||||
72a_blackwell_nvfp4_bf16_gemm
|
||||
72a_blackwell_nvfp4_bf16_gemm.cu
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
|
||||
cutlass_example_add_executable(
|
||||
73_blackwell_gemm_preferred_cluster
|
||||
blackwell_gemm_preferred_cluster.cu
|
||||
|
||||
@ -513,7 +513,7 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -29,9 +29,9 @@
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
74_blackwell_gemm_streamk
|
||||
blackwell_gemm_streamk.cu
|
||||
if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f")
|
||||
cutlass_example_add_executable(
|
||||
74_blackwell_gemm_streamk
|
||||
blackwell_gemm_streamk.cu
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -61,7 +61,7 @@
|
||||
# Heuristic mode with deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic
|
||||
|
||||
# Stream-K mode with determinsitic reduction
|
||||
# Stream-K mode with deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic
|
||||
|
||||
# Split-K mode with a splitting factor of 2 and deterministic reduction
|
||||
@@ -556,10 +556,19 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));

if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0;
if (__CUDACC_VER_MAJOR__ < 13) {
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
return 0;
}
}
else {
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
return 0;
}
}

//
// Parse options
//

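A hedged reading of the gate added above, as a small helper: accept compute capability 10.0/10.1/10.3 when compiling with CUDA 12, and the wider Blackwell family when compiling with CUDA 13 or newer. This is a sketch of the intent, not the example's exact predicate, and the helper name is illustrative:

#include <cuda_runtime.h>

// Returns true if the current device is one the example targets, keyed on the
// CUDA toolkit major version used to build it.
bool is_supported_blackwell(const cudaDeviceProp& props) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 13)
  return (props.major == 10) || (props.major == 11 && props.minor == 0);
#else
  return props.major == 10 &&
         (props.minor == 0 || props.minor == 1 || props.minor == 3);
#endif
}

// Typical use at the top of main(): query the device with cudaGetDevice /
// cudaGetDeviceProperties, then return early when the check fails.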
@ -762,9 +762,8 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -130,16 +130,15 @@ constexpr int OutputSFVectorSize = 16;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
|
||||
cutlass::epilogue::thread::SiLu,
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementAccumulator,
|
||||
ElementD,
|
||||
ElementAccumulator,
|
||||
ElementSFD,
|
||||
LayoutC,
|
||||
ElementC>;
|
||||
|
||||
// Core kernel configurations
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
|
||||
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
|
||||
// Runtime Cluster Shape
|
||||
@ -159,7 +158,7 @@ struct MMA2SMConfig {
|
||||
};
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
typename MMA1SMConfig::MmaTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
@ -169,7 +168,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
@ -187,7 +186,7 @@ using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = Gemm1SM;
|
||||
|
||||
using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
typename MMA2SMConfig::MmaTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
@ -197,13 +196,13 @@ using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::Collective
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename MMA2SMConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue2SM::SharedStorage))>,
|
||||
typename MMA2SMConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal<
|
||||
@ -222,7 +221,7 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutS
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig<
|
||||
OutputSFVectorSize,
|
||||
OutputSFVectorSize,
|
||||
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
|
||||
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
|
||||
>;
|
||||
@ -233,7 +232,7 @@ using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFA> layout_SFB_host;
|
||||
std::vector<LayoutSFB> layout_SFB_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
@ -287,13 +286,7 @@ cutlass::DeviceAllocation<ElementAccumulator> norm_constant_device;
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -529,7 +522,7 @@ bool initialize_block(
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -785,9 +778,9 @@ bool verify(const Options &options) {
|
||||
decltype(tensor_SFA),
|
||||
decltype(tensor_B),
|
||||
decltype(tensor_SFB)
|
||||
>
|
||||
>
|
||||
mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C);
|
||||
auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D);
|
||||
|
||||
@ -856,8 +849,8 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::cout << " Verfication is turned off for this run." << std::endl;
|
||||
}
|
||||
std::cout << " Verification is turned off for this run." << std::endl;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
@ -903,9 +896,8 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -933,7 +925,7 @@ int main(int argc, char const **args) {
|
||||
std::cout << "Running kernel with 1SM MMA config:" << std::endl;
|
||||
run<Gemm1SM>(options, false /*host_problem_shapes_available*/);
|
||||
std::cout << "Running kernel with 2SM MMA config:" << std::endl;
|
||||
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
|
||||
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if("100a" IN_LIST CUTLASS_NVCC_ARCHS)
|
||||
cutlass_example_add_executable(
|
||||
75_blackwell_grouped_gemm
|
||||
75_blackwell_grouped_gemm.cu
|
||||
|
||||
@@ -36,7 +36,7 @@
APIs on NVIDIA Blackwell SM100 architecture.

The basic computation logic of dgrad convolution kernel is, take 3D convolution as an example:
Xformed Actication (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
Xformed Activation (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)

where in terms of GEMM perspective,
Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation
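A worked sketch of the GEMM view described above for dgrad (illustrative names, not code from the example): with the Activation (NDHWC) as the GEMM output operand, the implicit-GEMM extents follow directly from the convolution shape.

// M = N*D*H*W (one row per output activation pixel), N = C (input channels),
// K = K*T*R*S (output channels times filter taps).
struct ConvShape { int n, d, h, w, c, k, t, r, s; };
struct GemmExtent { long long m, n, k; };

GemmExtent dgrad_as_gemm(const ConvShape& cs) {
  GemmExtent e;
  e.m = 1LL * cs.n * cs.d * cs.h * cs.w;
  e.n = cs.c;
  e.k = 1LL * cs.k * cs.t * cs.r * cs.s;
  return e;
}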
@ -504,10 +504,19 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
if (__CUDACC_VER_MAJOR__ < 13) {
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -36,7 +36,7 @@
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of fprop convolution kernel is, take 3D convolution as an example:
|
||||
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK)
|
||||
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Activation (NZPQK)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation
|
||||
@ -504,10 +504,19 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
if (__CUDACC_VER_MAJOR__ < 13) {
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -36,7 +36,7 @@
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example:
|
||||
Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
|
||||
Xformed Activation (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter
|
||||
@ -500,10 +500,19 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
if (__CUDACC_VER_MAJOR__ < 13) {
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -116,6 +116,8 @@ struct Options {
|
||||
int h_k = 1;
|
||||
int q = 256;
|
||||
int k = 256;
|
||||
std::vector<int> varlen_q;
|
||||
std::vector<int> varlen_k;
|
||||
int d = 128;
|
||||
int warmup_iterations = 1;
|
||||
int iterations = 3;
|
||||
@ -124,6 +126,7 @@ struct Options {
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
bool causal_q_begin = true;
|
||||
bool residual = false;
|
||||
bool varlen = false;
|
||||
bool persistent = false;
|
||||
@ -181,13 +184,76 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
|
||||
std::string varlen_q_str;
|
||||
cmd.get_cmd_line_argument("varlen-q", varlen_q_str);
|
||||
std::string varlen_k_str;
|
||||
cmd.get_cmd_line_argument("varlen-k", varlen_k_str);
|
||||
|
||||
if (varlen && ! varlen_q_str.empty()) {
|
||||
varlen_q.clear();
|
||||
while (! varlen_q_str.empty()) {
|
||||
size_t pos = varlen_q_str.find(':');
|
||||
varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos)));
|
||||
if (pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
varlen_q_str = varlen_q_str.substr(pos + 1);
|
||||
}
|
||||
if (b == -1) {
|
||||
b = static_cast<int>(varlen_q.size());
|
||||
}
|
||||
if (b != static_cast<int>(varlen_q.size())) {
|
||||
std::cout << "Error: Invalid --varlen-q length\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
int new_q = 0;
|
||||
for (auto elem : varlen_q) {
|
||||
new_q += elem;
|
||||
}
|
||||
if (q != -1) {
|
||||
std::cout << "Error: Can't provide --q and --varlen-q\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
q = new_q;
|
||||
}
|
||||
|
||||
if (varlen && ! varlen_k_str.empty()) {
|
||||
varlen_k.clear();
|
||||
while (! varlen_k_str.empty()) {
|
||||
size_t pos = varlen_k_str.find(':');
|
||||
varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos)));
|
||||
if (pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
varlen_k_str = varlen_k_str.substr(pos + 1);
|
||||
}
|
||||
if (b == -1) {
|
||||
b = static_cast<int>(varlen_k.size());
|
||||
}
|
||||
if (b != static_cast<int>(varlen_k.size())) {
|
||||
std::cout << " Error: Invalid --varlen-k length\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
int new_k = 0;
|
||||
for (auto elem : varlen_k) {
|
||||
new_k += elem;
|
||||
}
|
||||
if (k != -1) {
|
||||
std::cout << "Error: Can't provide --k and --varlen-k\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
k = new_k;
|
||||
}
|
||||
|
||||
if (q == -1) q = k;
if (k == -1) k = q;
if (q == -1 && k == -1) q = k = defaults.q;

cmd.get_cmd_line_argument("b", b, -1);
if (b == -1) b = 16384 / k;
if (b == 0) b = 1;

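The hunk above parses the new --varlen-q/--varlen-k flags as colon-separated per-batch extents, derives the batch count from the list length when --b is not given, and sums the entries to obtain the total sequence length. A standalone sketch of the parsing step (the function name is illustrative; the example inlines this logic for each flag):

#include <string>
#include <vector>

// "128:512:64" -> {128, 512, 64}
std::vector<int> parse_colon_list(std::string s) {
  std::vector<int> values;
  while (!s.empty()) {
    auto pos = s.find(':');
    values.push_back(std::stoi(s.substr(0, pos)));
    if (pos == std::string::npos) {
      break;
    }
    s = s.substr(pos + 1);
  }
  return values;
}

For example, "--varlen-q=128:512:64" yields {128, 512, 64}, so b becomes 3 and q becomes 704.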
@ -197,11 +263,12 @@ struct Options {
|
||||
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
persistent = cmd.check_cmd_line_flag("persistent");
|
||||
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
std::string causal_type;
|
||||
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
|
||||
if (mask == "no" || mask == "") {
|
||||
causal = residual = false;
|
||||
if (varlen) {
|
||||
@ -211,6 +278,11 @@ struct Options {
|
||||
else if (mask == "causal") {
|
||||
residual = false;
|
||||
causal = true;
|
||||
if(causal_type == "qend") {
|
||||
causal_q_begin = false;
|
||||
} else {
|
||||
causal_q_begin = true;
|
||||
}
|
||||
}
|
||||
else if (mask == "residual") {
|
||||
residual = true;
|
||||
@ -240,13 +312,16 @@ struct Options {
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
|
||||
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
|
||||
<< " --d=<int> Sets the D extent\n"
|
||||
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
|
||||
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||
<< " --causal-type=<qbegin|qend> Causal mask type\n"
|
||||
<< " --persistent Enables persistent scheduler\n"
|
||||
<< " --varlen Enables variable sequence length\n"
|
||||
<< " B*Q and B*K become the total sequence length\n"
|
||||
@ -344,16 +419,16 @@ struct FwdRunner {
|
||||
using ElementAccumulatorPV = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
// Q K D (B H)
|
||||
// Q K D ((H_R, H_K) B)
|
||||
using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
|
||||
|
||||
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D (H_G H_R B)
|
||||
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D (H_G H_R B)
|
||||
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D ((H_R, H_K), B)
|
||||
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D ((H_R, H_K), B)
|
||||
using StrideV = StrideK;
|
||||
using StrideO = StrideQ;
|
||||
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q (H_G H_R B)
|
||||
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q ((H_R, H_K), B)
|
||||
|
||||
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
|
||||
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
|
||||
@ -439,8 +514,12 @@ struct FwdRunner {
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
auto problem_shape_ref = cute::make_tuple(Q, K, D, D, HB);
|
||||
|
||||
fmha_reference(problem_shape_ref, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
@ -475,7 +554,10 @@ struct FwdRunner {
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) {
|
||||
auto initialize_varlen(
|
||||
const Options& options, const ProblemShape& problem_size,
|
||||
const bool kVarlenSame = true) {
|
||||
|
||||
int num_batches = get<3,1>(problem_size);
|
||||
|
||||
// generate Q as --b times
|
||||
@ -503,8 +585,12 @@ struct FwdRunner {
|
||||
int max_seqlen_kv = 0;
|
||||
|
||||
for (int i = 0; i < num_batches; i++) {
|
||||
int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng);
|
||||
int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng);
|
||||
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
|
||||
kVarlenSame ? get<0>(problem_size) :
|
||||
generate_positive_int(dist_q, rng);
|
||||
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
|
||||
kVarlenSame ? get<1>(problem_size) :
|
||||
generate_positive_int(dist_kv, rng);
|
||||
|
||||
total_seqlen_q += seqlen_q;
|
||||
total_seqlen_kv += seqlen_kv;
|
||||
@ -525,8 +611,8 @@ struct FwdRunner {
|
||||
|
||||
ProblemShapeType problem_size_for_launch;
|
||||
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
|
||||
get<2>(problem_size_for_launch) = get<2>(problem_size);
|
||||
get<3>(problem_size_for_launch) = get<3>(problem_size);
|
||||
|
||||
@ -545,7 +631,7 @@ struct FwdRunner {
|
||||
decltype(problem_shape_in) problem_size;
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
|
||||
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(options, problem_shape_in);
|
||||
problem_shape = problem_shape_launch;
|
||||
problem_size = problem_shape_init;
|
||||
}
|
||||
@ -583,11 +669,13 @@ struct FwdRunner {
|
||||
}
|
||||
|
||||
auto buffer_init_fn = [&](auto& buffer) {
|
||||
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
buffer.block_Q.reset(size(shape_QO));
|
||||
buffer.block_K.reset(size(shape_KV));
|
||||
buffer.block_V.reset(size(shape_KV));
|
||||
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_LSE.reset(size(shape_LSE));
|
||||
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_ref_LSE.reset(size(shape_LSE));
|
||||
|
||||
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
|
||||
@@ -778,7 +866,7 @@ struct FwdRunner {
flops *= static_cast<double>(size<1>(problem_shape));
flops *= static_cast<double>(size<3,1>(problem_shape));
}
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask<true>> || std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
flops *= static_cast<double>(size<2>(problem_shape));
flops *= static_cast<double>(size<3,0>(problem_shape));
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
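A sketch of the FLOP model behind the lines above (an assumption-level restatement, not the runner's code): forward attention performs two GEMMs per head, Q·K^T and P·V, each counting 2·Q·K·D multiply-add operations, hence the factor of 4; a causal mask touches roughly half of the Q×K area, hence the 0.5. The function and parameter names are illustrative:

// Estimated tensor-core throughput in TFLOPS/s for one forward attention pass.
double fmha_tflops_per_s(double q, double k, double d, double heads, double batch,
                         bool causal, double runtime_ms) {
  double flops = 4.0 * q * k * d * heads * batch * (causal ? 0.5 : 1.0);
  return flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
}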
@ -813,11 +901,18 @@ struct FwdRunner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
if (! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
@ -992,7 +1087,11 @@ int main_single(int argc, char const **args) {
|
||||
|
||||
auto with_mask = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask{});
|
||||
if(options.causal_q_begin) {
|
||||
fn(CausalMask{});
|
||||
} else {
|
||||
fn(CausalMask<false>{});
|
||||
}
|
||||
}
|
||||
else if (options.residual) {
|
||||
fn(ResidualMask{});
|
||||
@ -1018,7 +1117,7 @@ int main_single(int argc, char const **args) {
|
||||
});
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1026,8 +1125,6 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -1054,7 +1151,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
\brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3.
|
||||
|
||||
This example showcases the use of CUTLASS to build backward fused
|
||||
multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting
|
||||
multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting
|
||||
the NVIDIA Blackwell architecture.
|
||||
|
||||
Background and motivation
|
||||
@ -114,12 +114,17 @@ struct Options {
|
||||
int h_k = 1;
|
||||
int q = 1024;
|
||||
int k = 1024;
|
||||
std::vector<int> varlen_q;
|
||||
std::vector<int> varlen_k;
|
||||
int d = 128;
|
||||
int d_vo = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
bool residual = false;
|
||||
bool varlen = false;
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
@ -174,16 +179,82 @@ struct Options {
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("d_vo", d_vo, d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
std::string varlen_q_str;
|
||||
cmd.get_cmd_line_argument("varlen-q", varlen_q_str);
|
||||
std::string varlen_k_str;
|
||||
cmd.get_cmd_line_argument("varlen-k", varlen_k_str);
|
||||
|
||||
if (varlen && ! varlen_q_str.empty()) {
|
||||
varlen_q.clear();
|
||||
while (! varlen_q_str.empty()) {
|
||||
size_t pos = varlen_q_str.find(':');
|
||||
varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos)));
|
||||
if (pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
varlen_q_str = varlen_q_str.substr(pos + 1);
|
||||
}
|
||||
if (b == -1) {
|
||||
b = static_cast<int>(varlen_q.size());
|
||||
}
|
||||
if (b != static_cast<int>(varlen_q.size())) {
|
||||
std::cout << "Error: Invalid --varlen-q length\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
int new_q = 0;
|
||||
for (auto elem : varlen_q) {
|
||||
new_q += elem;
|
||||
}
|
||||
if (q != -1) {
|
||||
std::cout << "Error: Can't provide --q and --varlen-q\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
q = new_q;
|
||||
}
|
||||
|
||||
if (varlen && ! varlen_k_str.empty()) {
|
||||
varlen_k.clear();
|
||||
while (! varlen_k_str.empty()) {
|
||||
size_t pos = varlen_k_str.find(':');
|
||||
varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos)));
|
||||
if (pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
varlen_k_str = varlen_k_str.substr(pos + 1);
|
||||
}
|
||||
if (b == -1) {
|
||||
b = static_cast<int>(varlen_k.size());
|
||||
}
|
||||
if (b != static_cast<int>(varlen_k.size())) {
|
||||
std::cout << " Error: Invalid --varlen-k length\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
int new_k = 0;
|
||||
for (auto elem : varlen_k) {
|
||||
new_k += elem;
|
||||
}
|
||||
if (k != -1) {
|
||||
std::cout << "Error: Can't provide --k and --varlen-k\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
k = new_k;
|
||||
}
|
||||
|
||||
if (q == -1) q = k;
|
||||
if (k == -1) k = q;
|
||||
if (q == -1 && k == -1) q = k = defaults.q;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
@ -195,9 +266,15 @@ struct Options {
|
||||
if (mask == "causal") {
|
||||
causal = true;
|
||||
}
|
||||
else if (mask == "residual") {
|
||||
residual = true;
|
||||
}
|
||||
else {
|
||||
causal = defaults.causal;
|
||||
}
|
||||
if (varlen) {
|
||||
residual = true;
|
||||
}
|
||||
|
||||
skip_reference = cmd.check_cmd_line_flag("skip-reference");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
@ -224,13 +301,22 @@ struct Options {
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
|
||||
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
|
||||
<< " --d=<int> Sets the D extent\n"
|
||||
<< " --d_vo=<int> Sets the D_VO extent\n"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|causal> Enables masking\n"
|
||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||
<< " --varlen Enables variable sequence length\n"
|
||||
<< " B*Q and B*K become the total sequence length\n"
|
||||
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
||||
<< " with the last batch sized to make it fit\n"
|
||||
<< " implies at least residual masking for correctness\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
@ -307,6 +393,8 @@ struct ExampleResult {
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
bool kIsVarlen,
|
||||
bool kIsMla,
|
||||
class TileShape,
|
||||
class DispatchPolicy,
|
||||
class ActiveMask,
|
||||
@ -321,23 +409,24 @@ struct BwdRunner {
|
||||
#endif
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Q K D (H B)
|
||||
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
|
||||
|
||||
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
|
||||
// Q K D D_VO ((H_R, H_K) B)
|
||||
using ProblemShape = std::conditional_t<
|
||||
kIsVarlen,
|
||||
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
|
||||
cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
|
||||
>;
|
||||
|
||||
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
|
||||
using StrideQ = TensorStride;
|
||||
using StrideK = TensorStride;
|
||||
using StrideV = TensorStride;
|
||||
using StrideO = TensorStride;
|
||||
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
|
||||
using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
|
||||
using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
|
||||
using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
|
||||
using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
|
||||
using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)
|
||||
|
||||
// Backwards specific
|
||||
using StrideDQ = TensorStride;
|
||||
using StrideDK = TensorStride;
|
||||
using StrideDV = TensorStride;
|
||||
using StrideDO = TensorStride;
|
||||
using StrideDQ = StrideQ;
|
||||
using StrideDK = StrideK;
|
||||
using StrideDV = StrideV;
|
||||
using StrideDO = StrideO;
|
||||
|
||||
//
|
||||
// Data members
|
||||
@ -363,6 +452,9 @@ struct BwdRunner {
|
||||
DeviceAllocation<Element> block_O;
|
||||
DeviceAllocation<ElementAccumulator> block_LSE;
|
||||
|
||||
DeviceAllocation<int> block_cumulative_seqlen_q;
|
||||
DeviceAllocation<int> block_cumulative_seqlen_kv;
|
||||
|
||||
DeviceAllocation<Element> block_dQ;
|
||||
DeviceAllocation<Element> block_dK;
|
||||
DeviceAllocation<Element> block_dV;
|
||||
@ -375,47 +467,19 @@ struct BwdRunner {
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
bool verify(const ProblemShapeType& problem_shape) {
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
bool verify(const ProblemShape& problem_shape) {
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
auto [H, B] = HB;
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
// keep going here! (this might be better in cursor)
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_dQ);
|
||||
|
||||
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_dK);
|
||||
|
||||
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_dV);
|
||||
|
||||
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_dO);
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), make_shape(Q, D, HB), stride_Q);
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), make_shape(K, D, HB), stride_K);
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), make_shape(K, D_VO, HB), stride_V);
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), make_shape(Q, D_VO, HB), stride_O);
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), make_shape(Q, HB), stride_LSE);
|
||||
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), make_shape(Q, D, HB), stride_dQ);
|
||||
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), make_shape(K, D, HB), stride_dK);
|
||||
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV);
|
||||
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO);
|
||||
|
||||
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
|
||||
|
||||
@ -459,22 +523,94 @@ struct BwdRunner {
|
||||
return passed_dQ && passed_dK && passed_dV;
|
||||
}
|
||||
|
||||
auto initialize_problem_shape(Options const& options) {
|
||||
int h_r = options.h / options.h_k;
|
||||
assert(options.h % options.h_k == 0);
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
int num_batches = options.b;
|
||||
|
||||
// generate Q as --b times
|
||||
// gaussian (--Q, --Q / 2) sampled positive
|
||||
// track cumulative
|
||||
std::mt19937 rng(0x202305151552ull);
|
||||
std::normal_distribution<double> dist_q(options.q, options.q / 2);
|
||||
std::normal_distribution<double> dist_kv(options.k, options.k / 2);
|
||||
|
||||
auto generate_positive_int = [](auto& dist, auto& gen) {
|
||||
// "0" is a valid value we test here
|
||||
return std::max(0, static_cast<int>(dist(gen)));
|
||||
};
|
||||
|
||||
std::vector<int> cumulative_seqlen_q = {0};
|
||||
std::vector<int> cumulative_seqlen_kv = {0};
|
||||
|
||||
int total_seqlen_q = 0;
|
||||
int total_seqlen_kv = 0;
|
||||
int max_seqlen_q = 0;
|
||||
int max_seqlen_kv = 0;
|
||||
|
||||
const bool kVarlenSame = false;
|
||||
for (int i = 0; i < num_batches; i++) {
|
||||
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
|
||||
kVarlenSame ? options.q :
|
||||
generate_positive_int(dist_q, rng);
|
||||
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
|
||||
kVarlenSame ? options.k :
|
||||
generate_positive_int(dist_kv, rng);
|
||||
|
||||
total_seqlen_q += seqlen_q;
|
||||
total_seqlen_kv += seqlen_kv;
|
||||
|
||||
max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
|
||||
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
|
||||
|
||||
cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
|
||||
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
|
||||
}
|
||||
|
||||
block_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
||||
block_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
||||
block_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
||||
block_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
||||
|
||||
ProblemShape problem_shape{
|
||||
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
|
||||
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
|
||||
options.d, options.d_vo, {{h_r, options.h_k}, options.b}
|
||||
};
|
||||
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(make_shape(h_r, options.h_k), 1));
|
||||
|
||||
return cute::make_tuple(problem_shape, tensor_shape);
|
||||
}
|
||||
else {
|
||||
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {{h_r, options.h_k}, options.b}};
|
||||
return cute::make_tuple(problem_shape, problem_shape);
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
ProblemShape initialize(Options const& options) {
|
||||
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
|
||||
auto [Q, K, D, D_VO, HB] = tensor_shape;
|
||||
auto [H, B] = HB;
|
||||
auto [H_R, H_K] = H;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
|
||||
auto shape_QO = select<0,2,3>(problem_shape);
|
||||
auto shape_KV = select<1,2,3>(problem_shape);
|
||||
auto shape_LSE = select<0,3>(problem_shape);
|
||||
// for varlen, Q == total_Q, K == total_K, B = 1
|
||||
// but in problem_shape, they've got to be max_Q/max_K, and B = B
|
||||
|
||||
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
|
||||
stride_V = stride_K;
|
||||
stride_O = stride_Q;
|
||||
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
auto shape_Q = make_shape(Q, D, HB);
|
||||
auto shape_K = make_shape(K, D, HB);
|
||||
auto shape_V = make_shape(K, D_VO, HB);
|
||||
auto shape_O = make_shape(Q, D_VO, HB);
|
||||
auto shape_LSE = make_shape(Q, HB);
|
||||
|
||||
stride_Q = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
|
||||
stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), B == 1 ? 0 : D*K*H_K));
|
||||
stride_V = make_stride(D_VO, _1{}, make_stride(make_stride(_0{},D_VO*K), B == 1 ? 0 : D_VO*K*H_K));
|
||||
stride_O = make_stride(D_VO, _1{}, make_stride(make_stride(D_VO*Q, D_VO*Q*H_R), B == 1 ? 0 : D_VO*Q*H_R*H_K));
|
||||
stride_LSE = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
|
||||
|
||||
stride_dQ = stride_Q;
|
||||
stride_dK = stride_K;
|
||||
@ -485,58 +621,72 @@ struct BwdRunner {
|
||||
return size(make_shape(1ull, shape));
|
||||
};
|
||||
|
||||
block_Q.reset(lsize(shape_QO));
|
||||
block_K.reset(lsize(shape_KV));
|
||||
block_V.reset(lsize(shape_KV));
|
||||
block_O.reset(lsize(shape_QO));
|
||||
auto size_K = lsize(K * D * H_K * B);
|
||||
auto size_V = lsize(K * D_VO * H_K * B);
|
||||
|
||||
block_Q.reset(lsize(shape_Q));
|
||||
block_K.reset(size_K);
|
||||
block_V.reset(size_V);
|
||||
block_O.reset(lsize(shape_O));
|
||||
block_LSE.reset(lsize(shape_LSE));
|
||||
|
||||
block_dQ.reset(lsize(shape_QO));
|
||||
block_dK.reset(lsize(shape_KV));
|
||||
block_dV.reset(lsize(shape_KV));
|
||||
block_dO.reset(lsize(shape_QO));
|
||||
block_dQ.reset(lsize(shape_Q));
|
||||
block_dK.reset(size_K);
|
||||
block_dV.reset(size_V);
|
||||
block_dO.reset(lsize(shape_O));
|
||||
|
||||
block_ref_dQ.reset(lsize(shape_QO));
|
||||
block_ref_dK.reset(lsize(shape_KV));
|
||||
block_ref_dV.reset(lsize(shape_KV));
|
||||
block_ref_dQ.reset(lsize(shape_Q));
|
||||
block_ref_dK.reset(size_K);
|
||||
block_ref_dV.reset(size_V);
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||
initialize_block(block_dO, seed + 2020, options.init_style_do);
|
||||
|
||||
initialize_block(block_dQ, seed + 2030, InitStyle::kOne);
|
||||
initialize_block(block_dK, seed + 2031, InitStyle::kOne);
|
||||
initialize_block(block_dV, seed + 2032, InitStyle::kOne);
|
||||
initialize_block(block_ref_dQ, seed + 2033);
|
||||
initialize_block(block_ref_dK, seed + 2034);
|
||||
initialize_block(block_ref_dV, seed + 2035);
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,2,4>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,2,4>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
select<1,3,4>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
select<0,3,4>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
select<0,4>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
if (! options.skip_reference) {
|
||||
if (not options.skip_reference) {
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
}
|
||||
|
||||
return problem_shape;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
|
||||
|
||||
initialize(problem_shape, options);
|
||||
auto problem_shape = initialize(options);
|
||||
|
||||
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
using Operation = cutlass::fmha::device::Sm100FmhaBwd<ProblemShape, Element, ElementAccumulator, TileShape, kIsMla, ActiveMask>;
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
block_Q.get(), stride_Q,
|
||||
@ -554,8 +704,6 @@ struct BwdRunner {
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
@ -650,12 +798,13 @@ struct BwdRunner {
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
|
||||
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask<false>> || std::is_same_v<ActiveMask, CausalForBackwardMask<true>> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(get<0>(problem_shape));
|
||||
flops *= static_cast<double>(get<1>(problem_shape));
|
||||
flops *= static_cast<double>(get<2>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,1>(problem_shape));
|
||||
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
|
||||
flops *= static_cast<double>(get<4,0,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<4,0,1>(problem_shape));
|
||||
flops *= static_cast<double>(get<4,1>(problem_shape));
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
@ -688,11 +837,18 @@ struct BwdRunner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
if (! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
@ -706,19 +862,33 @@ void print_result(const std::string& description, ExampleResult result, bool ver
|
||||
|
||||
struct KernelCoop {};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Fn>
|
||||
auto dispatch_bool(bool value, Fn fn) {
|
||||
if (value) {
|
||||
return fn(std::true_type{});
|
||||
}
|
||||
else {
|
||||
return fn(std::false_type{});
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
||||
BwdRunner<decltype(is_varlen)::value, false,decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
});
|
||||
};
|
||||
|
||||
using HeadDim = _64;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -726,14 +896,31 @@ void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
|
||||
template<class Mask>
|
||||
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
||||
BwdRunner<decltype(is_varlen)::value, false, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
});
|
||||
};
|
||||
|
||||
using HeadDim = _128;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
run(Shape<_128, _128, HeadDim, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
template<class Mask>
|
||||
void run_bwd_mla_192(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
dispatch_bool(options.varlen, [&](auto is_varlen) {
|
||||
BwdRunner<decltype(is_varlen)::value, true, decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
});
|
||||
};
|
||||
|
||||
using HeadDim = _192;
|
||||
|
||||
run(Shape<_64, _128, HeadDim, _128>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -797,13 +984,16 @@ int main_single(int argc, char const **args) {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
|
||||
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
auto with_causal = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask{});
|
||||
fn(CausalForBackwardMask{});
|
||||
}
|
||||
else if (options.residual) {
|
||||
fn(ResidualMaskForBackward{});
|
||||
}
|
||||
else {
|
||||
fn(NoMask{});
|
||||
@ -811,19 +1001,22 @@ int main_single(int argc, char const **args) {
|
||||
};
|
||||
|
||||
with_causal([&](auto fusion) {
|
||||
if (options.d <= 64) {
|
||||
if (options.d <= 64 && options.d_vo == options.d) {
|
||||
run_bwd_64(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 128) {
|
||||
else if (options.d <= 128 && options.d_vo == options.d) {
|
||||
run_bwd_128(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d == 192 && options.d_vo == 128) {
|
||||
run_bwd_mla_192(fusion, options, hw_info);
|
||||
}
|
||||
else {
|
||||
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -831,8 +1024,6 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -859,7 +1050,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -689,11 +689,18 @@ struct ExampleRunner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
|
||||
if (result.supported && ! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
|
||||
@ -781,12 +788,17 @@ int main_single(int argc, char const **args) {
|
||||
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
|
||||
)
|
||||
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
if (options.d == 128) {
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
}
|
||||
else {
|
||||
std::cout << "Head Dimension != 128 is not supported for the fmha_gen example\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
@ -797,8 +809,6 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -825,7 +835,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -80,6 +80,7 @@ struct Options {
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
bool is_fused_reduction = false;
|
||||
|
||||
int sm_count = 0;
|
||||
|
||||
@ -139,9 +140,12 @@ struct Options {
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
|
||||
if (split_kv == 0) {
|
||||
split_kv = 1;
|
||||
}
|
||||
cmd.get_cmd_line_argument("page", page, defaults.page);
|
||||
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
|
||||
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
|
||||
is_var_split_kv = cmd.check_cmd_line_flag("var_split_kv");
|
||||
if (page == -1) {
|
||||
is_var_split_kv = false;
|
||||
}
|
||||
@ -149,6 +153,10 @@ struct Options {
|
||||
if (is_var_split_kv == true) {
|
||||
split_kv = max_split_kv;
|
||||
}
|
||||
is_fused_reduction = cmd.check_cmd_line_flag("fuse_reduction");
|
||||
if (split_kv == 1) {
|
||||
is_fused_reduction = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
@ -176,6 +184,8 @@ struct Options {
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --spread=<float> Relative spread away from K for paging\n"
|
||||
<< " --split_kv=<int> Split KV factor\n"
|
||||
<< " --fused_reduction Fuse the reduction operation\n"
|
||||
<< " --var_split_kv Use varying split KV factor\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
@ -391,11 +401,7 @@ struct Runner {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
#ifdef B2B
|
||||
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#else
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#endif
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
@ -404,7 +410,6 @@ struct Runner {
|
||||
}
|
||||
|
||||
bool passed_LSE = true;
|
||||
#ifndef B2B
|
||||
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
@ -412,7 +417,6 @@ struct Runner {
|
||||
std::cerr << "failed LSE: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
@ -520,7 +524,8 @@ struct Runner {
|
||||
stride_LSE},
|
||||
hw_info,
|
||||
options.split_kv,
|
||||
options.is_var_split_kv ? block_split_kv.get() : nullptr
|
||||
options.is_var_split_kv ? block_split_kv.get() : nullptr,
|
||||
options.is_fused_reduction
|
||||
};
|
||||
if (options.split_kv < 0 && !options.is_var_split_kv) {
|
||||
Operation::set_split_kv(arguments);
|
||||
@ -678,11 +683,18 @@ struct Runner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
if (! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
|
||||
@ -723,13 +735,17 @@ void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
if (!options.is_fused_reduction || options.split_kv == 1) {
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
}
|
||||
#elif FP16
|
||||
name += " fp16";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
if (!options.is_fused_reduction || options.split_kv == 1) {
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -806,8 +822,6 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -834,7 +848,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
1087
examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu
Normal file
1087
examples/77_blackwell_fmha/77_blackwell_mla_fwd.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -33,24 +33,84 @@ set_property(
|
||||
77_blackwell_fmha_gen.cu
|
||||
77_blackwell_mla.cu
|
||||
77_blackwell_fmha_bwd.cu
|
||||
77_blackwell_mla_fwd.cu
|
||||
PROPERTY
|
||||
COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
|
||||
)
|
||||
|
||||
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
|
||||
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
set(TEST_CAUSAL_00 --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
set(TEST_CAUSAL_01 --verify --iterations=0 --b=1 --h=1 --h_k=1 --q=1013 --k=1024 --d=128 --mask=causal --causal-type=qend)
|
||||
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
|
||||
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
|
||||
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
|
||||
|
||||
set(TEST_VARLEN_00 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_01 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_02 --verify --varlen --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_03 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_04 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_05 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_06 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
|
||||
set(TEST_VARLEN_07 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
|
||||
set(TEST_VARLEN_08 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
|
||||
set(TEST_VARLEN_09 --verify --varlen --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
|
||||
set(TEST_VARLEN_10 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
|
||||
set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
|
||||
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
|
||||
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
|
||||
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
set(TEST_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
|
||||
set(TEST_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
|
||||
|
||||
|
||||
|
||||
set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_02 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_03 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_04 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_05 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_06 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
|
||||
set(TEST_MLA_FWD_VARLEN_07 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
|
||||
set(TEST_MLA_FWD_VARLEN_08 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
|
||||
set(TEST_MLA_FWD_VARLEN_09 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
|
||||
set(TEST_MLA_FWD_VARLEN_10 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=2:3 --varlen-k=2:5)
|
||||
set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=11:10 --varlen-k=13:10)
|
||||
set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
|
||||
set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
|
||||
set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
set(TEST_MLA_FWD_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_MLA_FWD_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
|
||||
set(TEST_MLA_FWD_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
|
||||
|
||||
|
||||
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
|
||||
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
|
||||
set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
|
||||
set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
|
||||
|
||||
set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify)
|
||||
set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify)
|
||||
|
||||
set(TEST_MLA_LARGE_SPLIT_KV --verify --split_kv=20 --page=128)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
|
||||
@ -62,10 +122,34 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
TEST_CAUSAL_00
|
||||
TEST_CAUSAL_01
|
||||
TEST_VARLEN
|
||||
TEST_HDIM64
|
||||
TEST_GQA
|
||||
TEST_VARLEN_00
|
||||
TEST_VARLEN_01
|
||||
TEST_VARLEN_02
|
||||
TEST_VARLEN_03
|
||||
TEST_VARLEN_04
|
||||
TEST_VARLEN_05
|
||||
TEST_VARLEN_06
|
||||
TEST_VARLEN_07
|
||||
TEST_VARLEN_08
|
||||
TEST_VARLEN_09
|
||||
TEST_VARLEN_10
|
||||
TEST_VARLEN_11
|
||||
TEST_VARLEN_12
|
||||
TEST_VARLEN_13
|
||||
TEST_VARLEN_14
|
||||
TEST_VARLEN_15
|
||||
TEST_VARLEN_16
|
||||
TEST_VARLEN_17
|
||||
TEST_VARLEN_18
|
||||
TEST_VARLEN_19
|
||||
TEST_VARLEN_20
|
||||
TEST_VARLEN_21
|
||||
TEST_VARLEN_22
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -75,11 +159,11 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
TEST_GEN_GQA
|
||||
TEST_GEN_REMAP
|
||||
TEST_GEN_CACHEONLY
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -89,6 +173,9 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
TEST_MLA_LARGE_SPLIT_KV
|
||||
TEST_MLA_SEP_REDUCTION
|
||||
TEST_MLA_FUSE_REDUCTION
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -99,50 +186,79 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
TEST_MLA_LARGE_SPLIT_KV
|
||||
TEST_MLA_SEP_REDUCTION
|
||||
TEST_MLA_FUSE_REDUCTION
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
|
||||
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_b2b_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
|
||||
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
TEST_VARLEN
|
||||
# NOTE: bwd doesn't support GQA yet, --h_k will just get ignored in these tests
|
||||
TEST_VARLEN_00
|
||||
TEST_VARLEN_01
|
||||
TEST_VARLEN_02
|
||||
TEST_VARLEN_03
|
||||
TEST_VARLEN_04
|
||||
TEST_VARLEN_05
|
||||
TEST_VARLEN_06
|
||||
TEST_VARLEN_07
|
||||
TEST_VARLEN_08
|
||||
TEST_VARLEN_09
|
||||
TEST_VARLEN_10
|
||||
TEST_VARLEN_11
|
||||
TEST_VARLEN_12
|
||||
TEST_VARLEN_13
|
||||
TEST_VARLEN_14
|
||||
TEST_BWD_MLA_BASIC
|
||||
TEST_BWD_MLA_VARLEN
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_sat_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
77_blackwell_mla_fwd_${PREC}
|
||||
77_blackwell_mla_fwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
TEST_CAUSAL_00
|
||||
TEST_VARLEN
|
||||
TEST_HDIM64
|
||||
TEST_GQA
|
||||
TEST_MLA_FWD_VARLEN_00
|
||||
TEST_MLA_FWD_VARLEN_01
|
||||
TEST_MLA_FWD_VARLEN_02
|
||||
TEST_MLA_FWD_VARLEN_03
|
||||
TEST_MLA_FWD_VARLEN_04
|
||||
TEST_MLA_FWD_VARLEN_05
|
||||
TEST_MLA_FWD_VARLEN_06
|
||||
TEST_MLA_FWD_VARLEN_07
|
||||
TEST_MLA_FWD_VARLEN_08
|
||||
TEST_MLA_FWD_VARLEN_09
|
||||
TEST_MLA_FWD_VARLEN_10
|
||||
TEST_MLA_FWD_VARLEN_11
|
||||
TEST_MLA_FWD_VARLEN_12
|
||||
TEST_MLA_FWD_VARLEN_13
|
||||
TEST_MLA_FWD_VARLEN_14
|
||||
TEST_MLA_FWD_VARLEN_15
|
||||
TEST_MLA_FWD_VARLEN_16
|
||||
TEST_MLA_FWD_VARLEN_17
|
||||
TEST_MLA_FWD_VARLEN_18
|
||||
TEST_MLA_FWD_VARLEN_19
|
||||
TEST_MLA_FWD_VARLEN_20
|
||||
TEST_MLA_FWD_VARLEN_21
|
||||
TEST_MLA_FWD_VARLEN_22
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
|
||||
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
|
||||
target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_mla_fwd_${PREC} PRIVATE -Xptxas -v)
|
||||
endforeach()
|
||||
|
||||
# Add a target that builds all examples
|
||||
@ -156,11 +272,9 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla_2sm_fp16
|
||||
77_blackwell_mla_2sm_cpasync_fp8
|
||||
77_blackwell_mla_2sm_cpasync_fp16
|
||||
77_blackwell_mla_b2b_2sm_fp8
|
||||
77_blackwell_mla_b2b_2sm_fp16
|
||||
77_blackwell_fmha_bwd_fp8
|
||||
77_blackwell_fmha_bwd_fp16
|
||||
77_blackwell_fmha_bwd_sat_fp8
|
||||
77_blackwell_fmha_bwd_sat_fp16
|
||||
77_blackwell_mla_fwd_fp8
|
||||
77_blackwell_mla_fwd_fp16
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -8,7 +8,7 @@ For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit
|
||||
|
||||
Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.
|
||||
|
||||
For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
|
||||
For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
|
||||
|
||||
The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
|
||||
The kernel and collective layer are then formulated to be fmha-specific.
|
||||
@ -37,13 +37,19 @@ There are three kernels to compute backwards:
|
||||
|
||||
`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.
|
||||
|
||||
## MLA Blackwell Backward
|
||||
|
||||
The sample also provides the feature of MLA backward(d=192, d_vo=128). To enable MLA backward, please specify `--d=192 --d_vo=128` when running the bwd sample.
|
||||
|
||||
`Sm100FmhaBwdMlaKernelTmaWarpSpecialized`is the main point for MLA backward. The MLA approach is slightly different from the original one to enable high performance with the MLA shape.
|
||||
|
||||
# MLA Inference for Blackwell
|
||||
|
||||
This sample provides code for fused multi-head latent attention inference in
|
||||
the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
|
||||
It supports fp16, bf16, and fp8 input and output types.
|
||||
|
||||
To accomodate the large output accumulator due to the large latent head dimension,
|
||||
To accommodate the large output accumulator due to the large latent head dimension,
|
||||
the sample demonstrates how to leverage 2Sm Blackwell tensor cores.
|
||||
|
||||
Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
|
||||
@ -55,6 +61,14 @@ The approach of this implementation is to reuse the selection logic of the colle
|
||||
The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back gemm (essentially turning the softmax into a no-op) for fp8 and fp16.
|
||||
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them.
|
||||
|
||||
# Changes
|
||||
|
||||
* 4.1.0: Enhanced testing of variable sequence length; disabled B2B mode in MLA
|
||||
to simplify the sample, clarified that `fmha_gen` sample only supports head
|
||||
dim 128.
|
||||
|
||||
* 4.3.0: For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` to the `problem_shape`.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
|
||||
@ -132,10 +132,68 @@ struct ResidualMask : NoMask {
|
||||
}
|
||||
};
|
||||
|
||||
struct ResidualMaskForBackward : NoMask {
|
||||
|
||||
using Base = NoMask;
|
||||
|
||||
template <class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// if the sequence length does not divide the tile size evenly
|
||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
|
||||
}
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// This is useful is seqlen_k % kBlockN != 0 since it masks
|
||||
// the remaining elements out from softmax.
|
||||
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
|
||||
// issues as they are transparently taken care of by TMA and the
|
||||
// epilogue, if it is instantiated with predication support.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if (! elem_less(pos, select<0,1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// There are two ways to do causal if N_Q != N_K
|
||||
// (1) The Q is at the beginning of the matrix
|
||||
// (2) The Q is at the end of the matrix
|
||||
template<bool kIsQBegin = true>
|
||||
struct CausalMask : NoMask {
|
||||
|
||||
using Base = NoMask;
|
||||
|
||||
static constexpr bool IsQBegin = kIsQBegin;
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
@ -146,8 +204,14 @@ struct CausalMask : NoMask {
|
||||
// See note below on different ways to think about causal attention
|
||||
// Again, we'd add the offset_q into the max_blocks_q calculation
|
||||
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
if constexpr (IsQBegin) {
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
} else {
|
||||
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
@ -156,9 +220,14 @@ struct CausalMask : NoMask {
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
|
||||
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
if constexpr (IsQBegin) {
|
||||
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
|
||||
} else {
|
||||
const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape);
|
||||
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
@ -171,6 +240,47 @@ struct CausalMask : NoMask {
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// There are two ways to do causal if N_Q != N_K
|
||||
// (1) is to assume that the Q is at the beginning of the matrix
|
||||
// - this is the default setting.
|
||||
// (2) is that it is at the end of the matrix
|
||||
// - this is usually what we want for inference settings
|
||||
// where we only compute the next row and use cache for the rest
|
||||
// - if you'd like this, you only need to set kIsQBegin=false
|
||||
|
||||
if constexpr (IsQBegin) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<bool kIsQBegin = true>
|
||||
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
|
||||
|
||||
using Base = CausalMask<kIsQBegin>;
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
@ -186,10 +296,16 @@ struct CausalMask : NoMask {
|
||||
// where we only compute the next row and use cache for the rest
|
||||
// - if you'd like this, you only need to add an offset like so:
|
||||
// get<0>(pos) + offset_q < get<1>(pos)
|
||||
int offset_q = 0;
|
||||
if constexpr (!kIsQBegin) {
|
||||
offset_q = get<1>(problem_size) - get<0>(problem_size);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
|
||||
if (masked) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
@ -200,22 +316,23 @@ struct CausalMask : NoMask {
|
||||
struct VariableLength {
  int max_length;
  int* cumulative_length = nullptr;
  int total_length = -1;

  CUTE_HOST_DEVICE operator int() const {
    return max_length;
  }
};

template<class T> struct is_variable_length : std::false_type {};
template<> struct is_variable_length<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;

template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
  return transform_leaf(shape, [&](auto const& s) {
    if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
    if constexpr (is_variable_length_v<decltype(s)>) {
      return s.cumulative_length[idx+1] - s.cumulative_length[idx];
    }
    else {
@ -230,7 +347,7 @@ constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
  auto new_shape = apply_variable_length(shape, idx);
  auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
    if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
    if constexpr (is_variable_length_v<decltype(s)>) {
      return cute::make_tuple(c, s.cumulative_length[idx]);
    }
    else {
@ -240,6 +357,30 @@ apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
  return cute::make_tuple(new_shape, new_coord);
}

template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
  auto idx = back(back(coord));
  auto result_shape = transform_leaf(shape, [&](auto const& s) {
    if constexpr (is_variable_length_v<decltype(s)>) {
      return s.cumulative_length[idx+1] - s.cumulative_length[idx];
    }
    else {
      return s;
    }
  });
  auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) {
    if constexpr (is_variable_length_v<decltype(s)>) {
      return s.cumulative_length[idx];
    }
    else {
      return _0{};
    }
  });
  return cute::make_tuple(result_shape, result_offset);
}

} // namespace cutlass::fmha::collective
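`VariableLength` carries a `cumulative_length` prefix sum per batch, and the transforms above turn a batch index into a per-sequence length plus a row offset into the packed tensors. A small sketch of that bookkeeping, under the assumption that `cumulative_length` has `batch + 1` entries starting at 0:

```cpp
#include <cstdio>
#include <vector>

struct SeqView { int length; int row_offset; };

// Sequence b occupies rows [cum[b], cum[b+1]) of the packed Q/K/V tensors.
SeqView sequence_view(const std::vector<int>& cum, int b) {
  return { cum[b + 1] - cum[b], cum[b] };
}

int main() {
  std::vector<int> cum = {0, 5, 12, 20};   // three sequences of length 5, 7 and 8
  for (int b = 0; b < 3; ++b) {
    SeqView v = sequence_view(cum, b);
    std::printf("batch %d: length %d, starts at packed row %d\n", b, v.length, v.row_offset);
  }
  return 0;
}
```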

namespace cute {

@ -42,7 +42,8 @@ template<
  class ElementAcc,
  class TileShape, // Q, D, _
  class StrideO, // Q, D, B
  class StrideLSE_ // Q, B
  class StrideLSE_, // Q, B
  class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

@ -55,7 +56,11 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
  using SmemLayoutO_ = SmemLayoutO;
  using StrideLSE = StrideLSE_;

  using ElementOut = Element;

  static const int NumWarpsEpilogue = 1;
  static const int NumWarpsLoad = 1;

  struct TensorStorage {

    using SmemLayoutO = SmemLayoutO_;
@ -85,6 +90,19 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
    StrideLSE dLSE;
  };

  // FMHA and MLA have different input ProblemShapes;
  // get problem_shape_O according to the input ProblemShape.
  template<class ProblemShape>
  CUTLASS_DEVICE static constexpr
  auto get_problem_shape_O (
      ProblemShape const& problem_shape) {
    if constexpr (rank_v<decltype(get<2>(ProblemShape{}))> == 2) {
      return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape));
    } else {
      return select<0,2,3>(problem_shape);
    }
  }
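The `rank_v == 2` check above is what distinguishes the MLA-style problem shape (whose third mode is itself a tuple) from the plain FMHA shape. A hedged stand-in of the same dispatch using `std::tuple`; treating the nested pair as `(D_latent, D_rope)` is an assumption made for illustration only.

```cpp
#include <cstdio>
#include <tuple>
#include <type_traits>

template <class T> struct is_tuple : std::false_type {};
template <class... Ts> struct is_tuple<std::tuple<Ts...>> : std::true_type {};

// Pick the head dimension used for the output tensor from a problem-shape tuple.
template <class Shape>
int output_head_dim(Shape const& s) {
  auto mode2 = std::get<2>(s);
  if constexpr (is_tuple<std::decay_t<decltype(mode2)>>::value) {
    return std::get<0>(mode2);   // nested mode: take the first entry (assumed D_latent)
  } else {
    return mode2;                // flat mode: a single head dimension
  }
}

int main() {
  std::printf("fmha D = %d\n", output_head_dim(std::make_tuple(128, 128, 64, 4)));
  std::printf("mla  D = %d\n", output_head_dim(std::make_tuple(128, 576, std::make_tuple(512, 64), 4)));
  return 0;
}
```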
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
@ -93,7 +111,8 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
auto ptr_O = args.ptr_O;
|
||||
StrideO dO = args.dO;
|
||||
auto problem_shape_O = select<0,2,3>(problem_shape);
|
||||
|
||||
auto problem_shape_O = get_problem_shape_O(problem_shape);
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
@ -145,7 +164,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
int o0_index = 2 * get<0>(blk_coord);
|
||||
int o1_index = 2 * get<0>(blk_coord) + 1;
|
||||
|
||||
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
|
||||
// offset mode 0 by (max_length - real_length)
|
||||
// offset mode 3,1 by cumulative_length + real_length
|
||||
// the ptr is already offset by - max_length
|
||||
@ -200,6 +219,11 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
tma_store_wait<0>();
|
||||
|
||||
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
|
||||
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
}
|
||||
|
||||
pipeline.consumer_release(pipeline_release_state);
|
||||
++pipeline_release_state;
|
||||
|
||||
|
||||
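When `OrderLoadEpilogue` is enabled, the epilogue's arrival on the named barrier above signals the load warp (which syncs on the same barrier in the new MLA load path further down) that the shared-memory region it wants to reuse has been drained. The following is only a host-side analogy of that handshake, using `std::latch` instead of the CUDA named barrier, to illustrate the ordering:

```cpp
#include <cstdio>
#include <latch>
#include <thread>

int shared_buffer = 0;          // stands in for the smem region shared by O and V
std::latch epilogue_done{1};    // stands in for the named-barrier handshake

void epilogue_warp() {
  std::printf("epilogue: stored O from shared_buffer = %d\n", shared_buffer);
  epilogue_done.count_down();   // "arrive": the buffer may now be reused
}

void load_warp() {
  epilogue_done.wait();         // "sync": do not overwrite the buffer too early
  shared_buffer = 42;           // now safe to stage the next V tile
  std::printf("load: staged next tile\n");
}

int main() {
  std::thread t1(epilogue_warp), t2(load_warp);
  t1.join(); t2.join();
  return 0;
}
```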
@ -58,7 +58,9 @@ template<
|
||||
// and refers to the two softmax warps
|
||||
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
|
||||
// (1, 2, 1) means they sit side by side (best for small Q / large K)
|
||||
class ThreadShape = Shape<_2, _1, _1>
|
||||
class ThreadShape = Shape<_2, _1, _1>,
|
||||
// Since shared memory is sufficient for FMHA, there is no need to reuse shared memory.
|
||||
class OrderLoadEpilogue = cute::false_type
|
||||
>
|
||||
struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
@ -76,7 +78,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
|
||||
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
|
||||
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
static const int Alignment = 128 / sizeof_bits_v<Element>;
|
||||
@ -106,6 +108,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountKV>{}));
|
||||
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountKV>{}));
|
||||
|
||||
// Reuse shared memory for V and O.
|
||||
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
|
||||
struct TensorStorage {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
union {
|
||||
@ -168,9 +172,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

  static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);

  static const int TransactionBytesLoadKV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
  static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
  static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);

  static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>), "K and V smem layouts must be of equal size");
  static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size");
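Splitting `TransactionBytesLoadKV` into per-tensor K and V constants makes the expected byte count of each TMA load explicit. The arithmetic is simply elements per pipeline stage times element width; a sketch with an assumed 128x128 fp16 stage (the stage shape is an assumption, not taken from the example):

```cpp
#include <cstdio>

constexpr int bits_to_bytes(long long bits) { return static_cast<int>((bits + 7) / 8); }

int main() {
  constexpr long long elems_per_stage = 128LL * 128LL;   // assumed K/V stage tile
  constexpr int bytes_k = bits_to_bytes(elems_per_stage * 16);  // fp16 = 16 bits
  constexpr int bytes_v = bits_to_bytes(elems_per_stage * 16);
  static_assert(bytes_k == bytes_v, "K and V smem layouts must be of equal size");
  std::printf("per-stage TMA transaction bytes: %d\n", bytes_k);
  return 0;
}
```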
|
||||
using Load = Sm100FmhaLoadTmaWarpspecialized<
|
||||
Element, StrideQ, StrideK, StrideV,
|
||||
@ -525,7 +530,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
|
||||
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
|
||||
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
|
||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
@ -663,7 +668,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

    pipeline_c.producer_acquire(pipeline_c_producer_state);

    ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
    ElementQK acc_scale = (old_row_max == row_max_safe) ? 0.5f : 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
    row_sum *= acc_scale;
    // row_sum = sum(reg_S)
    float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
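The guarded form above skips the `exp2f` when the running row maximum did not change, since the rescale factor is exactly 1 in that case. A host-side sketch of the running-sum rescale; the 0.5f folding used in the kernel is omitted, and `scale_log2` stands in for `params.scale_softmax_log2`:

```cpp
#include <cmath>
#include <cstdio>

float rescale_factor(float scale_log2, float old_row_max, float new_row_max) {
  // When the running max did not change, exp2f(0) == 1, so skip the call entirely.
  if (old_row_max == new_row_max) return 1.0f;
  return std::exp2(scale_log2 * (old_row_max - new_row_max));
}

int main() {
  float row_sum = 3.0f;
  // A new tile raised the row max from 2.0 to 2.5: older partial sums shrink accordingly.
  row_sum *= rescale_factor(1.4426950408889634f /* log2(e) */, 2.0f, 2.5f);
  std::printf("rescaled row_sum = %f\n", row_sum);
  return 0;
}
```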
@ -929,24 +934,31 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
}
|
||||
|
||||
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(tTMrO_i); j += 2) {
|
||||
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
|
||||
float2 out;
|
||||
cute::mul(out, scale_f32x2, in);
|
||||
tTMrO_i(j) = out.x;
|
||||
tTMrO_i(j+1) = out.y;
|
||||
|
||||
if (scale != 1.0f) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(tTMrO_i); j += 2) {
|
||||
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
|
||||
float2 out;
|
||||
cute::mul(out, scale_f32x2, in);
|
||||
tTMrO_i(j) = out.x;
|
||||
tTMrO_i(j+1) = out.y;
|
||||
}
|
||||
}
|
||||
|
||||
copy_out(i);
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
|
||||
template<
|
||||
class BlkCoord, class ProblemShape, class ParamsProblemShape,
|
||||
class TensorStorageEpi, class CollectiveEpilogue
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
correction(
|
||||
BlkCoord const& blk_coord,
|
||||
Params const& params, ProblemShape const& problem_shape,
|
||||
ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorageEpi& shared_storage_epi,
|
||||
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
|
||||
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
|
||||
@ -1000,11 +1012,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
    copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);

    // e^(scale * (old_max - new_max))
    float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
    float scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));

    pipeline_o.consumer_wait(pipeline_o_consumer_state);

    correction_rescale(scale, uint32_t(TmemAllocation::O0));
    bool warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
    if (warp_do_correction) {
      correction_rescale(scale, uint32_t(TmemAllocation::O0));
    }

    pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
    ++pipeline_s1_c_consumer_state;
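`__any_sync` turns the per-lane `scale != 1.0f` test into a warp-uniform decision, so either every lane of the correction warp enters `correction_rescale` or none does. A minimal CUDA sketch of that pattern (compiled with nvcc; the kernel body is a stand-in for the rescale, not the example's code):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

__global__ void rescale_if_needed(float* data, const float* scale) {
  float s = scale[threadIdx.x];
  // Warp-wide vote: skip the rescale only if no lane in the warp needs it.
  if (__any_sync(0xFFFFFFFF, s != 1.0f)) {
    data[threadIdx.x] *= s;   // stands in for correction_rescale()
  }
}

int main() {
  float *d_data, *d_scale;
  cudaMalloc(&d_data, 32 * sizeof(float));
  cudaMalloc(&d_scale, 32 * sizeof(float));
  cudaMemset(d_data, 0, 32 * sizeof(float));
  cudaMemset(d_scale, 0, 32 * sizeof(float));   // scale == 0.0f != 1.0f, so the vote fires
  rescale_if_needed<<<1, 32>>>(d_data, d_scale);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  cudaFree(d_scale);
  return 0;
}
```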
@ -1018,11 +1033,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
|
||||
|
||||
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
|
||||
scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
|
||||
|
||||
pipeline_o.consumer_wait(pipeline_o_consumer_state);
|
||||
|
||||
correction_rescale(scale, uint32_t(TmemAllocation::O1));
|
||||
warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
|
||||
if (warp_do_correction) {
|
||||
correction_rescale(scale, uint32_t(TmemAllocation::O1));
|
||||
}
|
||||
|
||||
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
|
||||
++pipeline_s0_c_consumer_state;
|
||||
@ -1061,17 +1079,22 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
// F2FP
|
||||
// store to smem
|
||||
    Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
    Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
    Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);

    correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);

    if (epilogue.params.ptr_LSE != nullptr) {
      int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);

      int row_offset = 0;
      if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
        row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
      }

      ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);

      if (row_idx < get<0>(problem_shape)) {
        gLSE(row_idx, get<2>(blk_coord)) = lse;
        gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
      }
    }
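The value written to `gLSE` is the usual log-sum-exp identity: with the running (scaled) maximum m and the max-subtracted partial sum s, log(s) + m equals log of the sum of exp(scale * x_i). A small numeric check of that identity; the variable names are mine, not the example's:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> scores = {1.0f, 2.5f, -0.5f};
  float scale = 0.125f;
  float m = -INFINITY, s = 0.0f;
  for (float x : scores) {                      // streaming (online) accumulation
    float m_new = std::max(m, scale * x);
    s = s * std::exp(m - m_new) + std::exp(scale * x - m_new);
    m = m_new;
  }
  float lse = std::log(s) + m;                  // matches log(sum exp(scale * x_i))
  float ref = 0.0f;
  for (float x : scores) ref += std::exp(scale * x);
  std::printf("lse = %f, log(sum) = %f\n", lse, std::log(ref));
  return 0;
}
```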
@ -1101,8 +1124,13 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx, get<2>(blk_coord)) = lse;
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1115,6 +1143,85 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
++pipeline_epi_producer_state;
|
||||
}
|
||||
|
||||
|
||||
template<
|
||||
class BlkCoord, class ProblemShape, class ParamsProblemShape,
|
||||
class TensorStorageEpi, class CollectiveEpilogue
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
correction_empty(
|
||||
BlkCoord const& blk_coord,
|
||||
Params const& params, ProblemShape const& problem_shape,
|
||||
ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorageEpi& shared_storage_epi,
|
||||
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
|
||||
CollectiveEpilogue& epilogue) {
|
||||
|
||||
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
|
||||
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
|
||||
float lse = -INFINITY;
|
||||
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
|
||||
|
||||
#define DSHOW(x) print(#x ": "); print(x); print("\n")
|
||||
if (threadIdx.x % 128 == 0 && block0()) {
|
||||
DSHOW(sO);
|
||||
}
|
||||
#if 1
|
||||
|
||||
using ElementOut = typename CollectiveEpilogue::ElementOut;
|
||||
auto tiled_copy = make_cotiled_copy(
|
||||
Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
|
||||
make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
|
||||
sO.layout());
|
||||
|
||||
auto thr_copy = tiled_copy.get_slice(thread_idx);
|
||||
auto tOgO = thr_copy.partition_D(sO);
|
||||
auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
|
||||
clear(tOrO);
|
||||
|
||||
copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
|
||||
#endif
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||
++pipeline_epi_producer_state;
|
||||
|
||||
copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||
++pipeline_epi_producer_state;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
|
||||
@ -86,10 +86,10 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
|
||||
static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{});
|
||||
|
||||
|
||||
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
|
||||
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
|
||||
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
static const int Alignment = 128 / sizeof_bits_v<Element>;
|
||||
@ -187,7 +187,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
SmemLayoutQ, SmemLayoutK, SmemLayoutV,
|
||||
PipelineQ, PipelineKV, TileShape, Mask
|
||||
>;
|
||||
|
||||
|
||||
struct Arguments {
|
||||
typename Load::Arguments load;
|
||||
|
||||
@ -534,14 +534,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
|
||||
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
|
||||
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
|
||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
|
||||
// Each thread owns a single row
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
|
||||
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
|
||||
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
|
||||
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
|
||||
|
||||
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
|
||||
@ -622,7 +622,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
|
||||
|
||||
const int kReleasePipeCount = 10; // must be multiple of 2
|
||||
|
||||
|
||||
order_s.wait();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -646,7 +646,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
}
|
||||
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
|
||||
|
||||
|
||||
|
||||
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
|
||||
order_s.arrive();
|
||||
}
|
||||
@ -672,7 +672,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
pipeline_c.producer_acquire(pipeline_c_producer_state);
|
||||
|
||||
ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
|
||||
ElementQK acc_scale = (old_row_max == row_max_safe) ? 0.5f : 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
|
||||
row_sum *= acc_scale;
|
||||
// row_sum = sum(reg_S)
|
||||
float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
|
||||
@ -700,7 +700,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
|
||||
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
|
||||
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
|
||||
|
||||
|
||||
row_sum = local_row_sum;
|
||||
|
||||
if (final_call) {
|
||||
@ -781,7 +781,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
template<class Vector, class GTensor, class CTensor, class Shape, class Epilogue>
|
||||
CUTLASS_DEVICE auto
|
||||
correction_epilogue(
|
||||
float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1,
|
||||
float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1,
|
||||
GTensor& gO, CTensor const& cO, Shape const& g_shape,
|
||||
Epilogue const& epilogue) {
|
||||
|
||||
@ -794,13 +794,13 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
// good values would be either 32 or 64
|
||||
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
|
||||
|
||||
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
||||
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
||||
|
||||
typename CollectiveMmaPV::TiledMma mma;
|
||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||
Tensor tOgO = mma.get_slice(0).partition_C(gO);
|
||||
|
||||
|
||||
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
@ -812,7 +812,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
|
||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||
|
||||
|
||||
Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0);
|
||||
Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1);
|
||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
|
||||
@ -831,7 +831,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
// loop:
|
||||
// TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 128 / kCorrectionTileSize; i++) {
|
||||
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
|
||||
Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0;
|
||||
tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||
Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1;
|
||||
@ -841,10 +841,10 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
Tensor tTMrO0 = make_tensor<ElementPV>(shape(tTMEM_LOADcO));
|
||||
Tensor tTMrO1 = make_tensor<ElementPV>(shape(tTMEM_LOADcO));
|
||||
|
||||
|
||||
copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0);
|
||||
copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1);
|
||||
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(tTMrO0); j += 2) {
|
||||
float2 in0 = make_float2(tTMrO0(j), tTMrO0(j+1));
|
||||
@ -891,24 +891,24 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
// good values would be either 32 or 64
|
||||
const int kCorrectionTileSize = 32;
|
||||
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
|
||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
|
||||
|
||||
typename CollectiveMmaPV::TiledMma mma;
|
||||
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||
|
||||
|
||||
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
|
||||
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
|
||||
|
||||
|
||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
|
||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
|
||||
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
|
||||
|
||||
|
||||
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
|
||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
|
||||
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
|
||||
@ -917,8 +917,8 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
float2 scale_f32x2 = make_float2(scale, scale);
|
||||
|
||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
||||
|
||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<get<2>(TileShape{}) / kCorrectionTileSize>{}));
|
||||
|
||||
auto copy_in = [&](int i) {
|
||||
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
||||
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||
@ -948,13 +948,16 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
}
|
||||
|
||||
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(tTMrO_i); j += 2) {
|
||||
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
|
||||
float2 out;
|
||||
cute::mul(out, scale_f32x2, in);
|
||||
tTMrO_i(j) = out.x;
|
||||
tTMrO_i(j+1) = out.y;
|
||||
|
||||
if (scale != 1.0f) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(tTMrO_i); j += 2) {
|
||||
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
|
||||
float2 out;
|
||||
cute::mul(out, scale_f32x2, in);
|
||||
tTMrO_i(j) = out.x;
|
||||
tTMrO_i(j+1) = out.y;
|
||||
}
|
||||
}
|
||||
|
||||
copy_out(i);
|
||||
@ -981,7 +984,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
|
||||
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
|
||||
|
||||
|
||||
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
|
||||
@ -1019,11 +1022,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
|
||||
|
||||
// e^(scale * (old_max - new_max))
|
||||
float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
|
||||
float scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
|
||||
|
||||
pipeline_o.consumer_wait(pipeline_o_consumer_state);
|
||||
|
||||
correction_rescale(scale, uint32_t(TmemAllocation::O0));
|
||||
bool warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
|
||||
if (warp_do_correction) {
|
||||
correction_rescale(scale, uint32_t(TmemAllocation::O0));
|
||||
}
|
||||
|
||||
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
|
||||
++pipeline_s1_c_consumer_state;
|
||||
@ -1037,11 +1043,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
|
||||
|
||||
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
|
||||
scale = (tTMEM_LOADVrS(kIdxOldRowMax) == tTMEM_LOADVrS(kIdxNewRowMax)) ? 1.0f : ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
|
||||
|
||||
pipeline_o.consumer_wait(pipeline_o_consumer_state);
|
||||
|
||||
correction_rescale(scale, uint32_t(TmemAllocation::O1));
|
||||
warp_do_correction = __any_sync(0xFFFFFFFF, scale != 1.0f);
|
||||
if (warp_do_correction) {
|
||||
correction_rescale(scale, uint32_t(TmemAllocation::O1));
|
||||
}
|
||||
|
||||
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
|
||||
++pipeline_s0_c_consumer_state;
|
||||
|
||||
@ -170,8 +170,8 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
auto tSgQ = thr_mma_qk.partition_A(gQ);
|
||||
auto tScQ = thr_mma_qk.partition_A(cQ);
|
||||
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
|
||||
|
||||
auto tiled_copy_q = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},
|
||||
|
||||
@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
auto dQ = args.dQ;
|
||||
auto dK = args.dK;
|
||||
auto dV = args.dV;
|
||||
auto problem_shape_qk = problem_shape;
|
||||
|
||||
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
|
||||
IntProblemShape problem_shape_qk;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dQ) = get<0>(dQ);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_Q -= max_length_q * get<0>(dQ);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
|
||||
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_kv != nullptr) {
|
||||
int max_length_kv = get<1>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dK) = get<0>(dK);
|
||||
get<2,1>(dV) = get<0>(dV);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_K -= max_length_kv * get<0>(dK);
|
||||
ptr_V -= max_length_kv * get<0>(dV);
|
||||
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
|
||||
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
|
||||
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
|
||||
get<2>(problem_shape_qk) = get<2>(problem_shape);
|
||||
get<3>(problem_shape_qk) = get<3>(problem_shape);
|
||||
}
|
||||
} else {
|
||||
problem_shape_qk = problem_shape;
|
||||
}
|
||||
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
|
||||
@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
|
||||
int q_offs_0 = 0;
|
||||
int q_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
q_offs_0 = max_length_q - get<0>(problem_shape);
|
||||
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
|
||||
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
|
||||
get<2,1>(blk_coord_q) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
|
||||
|
||||
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
|
||||
@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
|
||||
|
||||
int kv_offs_0 = 0;
|
||||
int kv_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
|
||||
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length != nullptr) {
|
||||
int max_length = get<1>(params_problem_shape).max_length;
|
||||
kv_offs_0 = max_length - get<1>(problem_shape);
|
||||
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
|
||||
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
|
||||
get<2,1>(blk_coord_kv) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
|
||||
|
||||
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
|
||||
@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
|
||||
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
|
||||
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
|
||||
|
||||
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
|
||||
|
||||
File diff suppressed because it is too large
@ -0,0 +1,323 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
#include "collective/fmha_common.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class StrideQ,
|
||||
class StrideK,
|
||||
class StrideV,
|
||||
class CollectiveMmaQK,
|
||||
class CollectiveMmaPV,
|
||||
class SmemLayoutQ,
|
||||
class SmemLayoutK,
|
||||
class SmemLayoutV,
|
||||
class TensorStorage,
|
||||
class PipelineQ,
|
||||
class PipelineKV,
|
||||
class Mask,
|
||||
class TileShape,
|
||||
class OrderLoadEpilogue = cute::false_type
|
||||
>
|
||||
struct Sm100MlaFwdLoadTmaWarpspecialized {
|
||||
|
||||
using TileShapeQK = typename CollectiveMmaQK::TileShape;
|
||||
using TileShapePV = typename CollectiveMmaPV::TileShape;
|
||||
|
||||
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
|
||||
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
|
||||
|
||||
static const int NumWarpsEpilogue = 1;
|
||||
static const int NumWarpsLoad = 1;
|
||||
|
||||
struct Arguments {
|
||||
const Element* ptr_Q;
|
||||
StrideQ dQ;
|
||||
const Element* ptr_K;
|
||||
StrideK dK;
|
||||
const Element* ptr_V;
|
||||
StrideV dV;
|
||||
};
|
||||
|
||||
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||
|
||||
struct Params {
|
||||
TMA_Q tma_load_q;
|
||||
TMA_K tma_load_k;
|
||||
TMA_V tma_load_v;
|
||||
};
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace) {
|
||||
|
||||
auto ptr_Q = args.ptr_Q;
|
||||
auto ptr_K = args.ptr_K;
|
||||
auto ptr_V = args.ptr_V;
|
||||
auto dQ = args.dQ;
|
||||
auto dK = args.dK;
|
||||
auto dV = args.dV;
|
||||
|
||||
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
|
||||
IntProblemShape problem_shape_qk;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
|
||||
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
|
||||
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
|
||||
get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
|
||||
get<3>(problem_shape_qk) = get<3>(problem_shape);
|
||||
}
|
||||
} else {
|
||||
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
|
||||
}
|
||||
|
||||
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
|
||||
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
|
||||
problem_shape_qk,
|
||||
typename CollectiveMmaQK::Arguments {
|
||||
ptr_Q, dQ,
|
||||
ptr_K, dK,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
|
||||
problem_shape_pv,
|
||||
typename CollectiveMmaPV::Arguments {
|
||||
ptr_K, dK, // never used, dummy
|
||||
ptr_V, select<1,0,2>(dV),
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
return Params{
|
||||
params_qk.tma_load_a,
|
||||
params_qk.tma_load_b,
|
||||
params_pv.tma_load_b
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& storage,
|
||||
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
|
||||
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
|
||||
|
||||
BlkCoord blk_coord_q = blk_coord_in;
|
||||
BlkCoord blk_coord_kv = blk_coord_in;
|
||||
|
||||
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
|
||||
auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));
|
||||
|
||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
// this one is only executed by one thread, no need to elect_one
|
||||
|
||||
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
|
||||
// two pipes: Q and KV
|
||||
// from Memory (prod) to TensorCore (cons)
|
||||
|
||||
// compute gQ, sQ
|
||||
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
|
||||
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
|
||||
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
|
||||
|
||||
int q_offs_0 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
|
||||
get<2,1>(blk_coord_q) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
|
||||
|
||||
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
auto [tQgQ_qdl, tQsQ] = tma_partition(
|
||||
params.tma_load_q, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
|
||||
);
|
||||
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
|
||||
|
||||
// compute gK, sK
|
||||
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
|
||||
|
||||
int kv_offs_0 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
|
||||
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length != nullptr) {
|
||||
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
|
||||
get<2,1>(blk_coord_kv) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
|
||||
|
||||
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
auto [tKgK_kdl, tKsK] = tma_partition(
|
||||
params.tma_load_k, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
|
||||
);
|
||||
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
|
||||
|
||||
// compute gV, sV
|
||||
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
|
||||
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
|
||||
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
|
||||
|
||||
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
|
||||
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
auto [tVgV_dkl, tVsV] = tma_partition(
|
||||
params.tma_load_v, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
|
||||
);
|
||||
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
|
||||
|
||||
// blk_coord is decomposed in terms of TileShape, not TileShapeQK
|
||||
// As such, it needs to be transformed as
|
||||
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
|
||||
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
|
||||
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Q1
|
||||
int q0_index = 2 * get<0>(blk_coord_q);
|
||||
int q1_index = 2 * get<0>(blk_coord_q) + 1;
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
// K1
|
||||
int k_index = 0;
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Q2
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
|
||||
cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
}
|
||||
|
||||
// V1
|
||||
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
// loop:
|
||||
mask_tile_count -= 1;
|
||||
for (; mask_tile_count > 0; mask_tile_count -= 1) {
|
||||
|
||||
// Ki
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
|
||||
|
||||
// prefetch vi
|
||||
cute::prefetch(params.tma_load_v, tVgV(_, k_index));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Vi
|
||||
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
|
||||
|
||||
// prefetch ki+1
|
||||
if(mask_tile_count > 1) {
|
||||
cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));
|
||||
}
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
250
examples/77_blackwell_fmha/common/pipeline_mla.hpp
Normal file
@ -0,0 +1,250 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Support the producer to acquire specific bytes of data.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/pipeline/sm100_pipeline.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <
|
||||
int Stages_,
|
||||
class ClusterShape = Shape<int,int,_1>,
|
||||
class AtomThrShape_MNK_ = Shape<_1,_1,_1>
|
||||
>
|
||||
class PipelineTmaAsyncMla {
|
||||
|
||||
public:
|
||||
static constexpr uint32_t Stages = Stages_;
|
||||
using AtomThrShape_MNK = AtomThrShape_MNK_;
|
||||
|
||||
private:
|
||||
using Impl = PipelineTmaUmmaAsync<Stages_, ClusterShape, AtomThrShape_MNK_>;
|
||||
|
||||
public:
|
||||
using FullBarrier = typename Impl::FullBarrier;
|
||||
using EmptyBarrier = typename Impl::EmptyBarrier;
|
||||
using ProducerBarrierType = typename Impl::ProducerBarrierType;
|
||||
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
|
||||
using PipelineState = typename Impl::PipelineState;
|
||||
using SharedStorage = typename Impl::SharedStorage;
|
||||
using ThreadCategory = typename Impl::ThreadCategory;
|
||||
using Params = typename Impl::Params;
|
||||
|
||||
|
||||
using McastDirection = McastDirection;
|
||||
|
||||
// Helper function to initialize barriers
|
||||
static
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
if (warp_idx == params.initializing_warp) {
|
||||
// Barrier FULL and EMPTY init
|
||||
constexpr int producer_arv_cnt = 1;
|
||||
auto atom_thr_shape = AtomThrShape_MNK{};
|
||||
uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
|
||||
(cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1;
|
||||
|
||||
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
|
||||
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
static
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {
|
||||
auto atom_thr_shape = AtomThrShape_MNK{};
|
||||
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
if (warp_idx == params.initializing_warp) {
|
||||
// Barrier FULL and EMPTY init
|
||||
constexpr int producer_arv_cnt = 1;
|
||||
uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ?
|
||||
cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas
|
||||
cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas
|
||||
|
||||
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
|
||||
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}

  CUTLASS_DEVICE
  void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
    // Calculate consumer mask
    if (params_.role == ThreadCategory::Consumer) {
      auto cluster_layout = make_layout(cluster_shape);
      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
  }

  CUTLASS_DEVICE
  void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
    // Calculate consumer mask
    dim3 block_id_in_cluster = cute::block_id_in_cluster();
    auto cluster_layout = make_layout(cluster_shape);
    if (mcast_direction == McastDirection::kRow) {
      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
    else {
      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
  }


public:
  template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
  CUTLASS_DEVICE
  PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})
    : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
    , params_(params)
    , empty_barrier_ptr_(&storage.empty_barrier_[0])
    , full_barrier_ptr_(&storage.full_barrier_[0]) {
    static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
    if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
      init_barriers(storage, params_, cluster_shape);
    }

    static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
    if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
      init_masks(cluster_shape);
    }
  }

  template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
  CUTLASS_DEVICE
  PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})
    : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
    , params_(params)
    , empty_barrier_ptr_(&storage.empty_barrier_[0])
    , full_barrier_ptr_(&storage.full_barrier_[0]) {
    static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);
    if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
      init_barriers(storage, params_, cluster_shape, mcast_direction);
    }

    static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);
    if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
      init_masks(cluster_shape, mcast_direction);
    }
  }


  CUTLASS_DEVICE
  void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
    impl_.producer_acquire(state, barrier_token);
  }

  CUTLASS_DEVICE
  void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {
    detail::pipeline_check_is_producer(params_.role);
    if (barrier_token != BarrierStatus::WaitDone) {
      empty_barrier_ptr_[stage].wait(phase);
    }

    if (params_.is_leader) {
      full_barrier_ptr_[stage].arrive_and_expect_tx(bytes);
    }
    #ifndef NDEBUG
    if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) {
      asm volatile ("brkpt;\n" ::);
    }

    // Most likely you have elected more than one leader
    if (params_.is_leader && (threadIdx.x % 32 != 0)) {
      asm volatile ("brkpt;\n" ::);
    }
    #endif
  }

  CUTLASS_DEVICE
  void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
    producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token);
  }
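  // Illustrative producer-side usage (a sketch under assumed names, not part of this diff):
  //   auto state = cutlass::make_producer_start_state<PipelineTmaAsyncMla>();
  //   for (int iter = 0; iter < iterations; ++iter) {
  //     pipeline.producer_acquire_bytes(state, transaction_bytes);  // wait on EMPTY; leader posts expect-tx
  //     // issue TMA loads that commit `transaction_bytes` to producer_get_barrier(state)
  //     ++state;
  //   }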

  CUTLASS_DEVICE
  ProducerBarrierType* producer_get_barrier(PipelineState state) {
    return impl_.producer_get_barrier(state);
  }

  CUTLASS_DEVICE
  void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
    impl_.consumer_wait(state, barrier_token);
  }

  CUTLASS_DEVICE
  void consumer_release(PipelineState state) {
    consumer_release(state.index(), false);
  }

private:
  Impl impl_;
  Params params_;
  EmptyBarrier *empty_barrier_ptr_;
  FullBarrier *full_barrier_ptr_;
  uint16_t block_id_mask_ = 0;
  static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;

  // Consumer signalling Producer of completion
  // Ensures all blocks in the same row and column get notified.
  CUTLASS_DEVICE
  void consumer_release(uint32_t stage, uint32_t skip) {
    detail::pipeline_check_is_consumer(params_.role);
    uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
    if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
      if (!skip) {
        cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
      }
    }
    else {
      if (!skip) {
        if constexpr (cute::is_static_v<ClusterShape> and size(ClusterShape{}) == 1) {
          cutlass::arch::umma_arrive(smem_ptr);
        }
        else {
          cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
        }
      }
    }
  }
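
  // Note: the release above is signalled through umma_arrive / umma_arrive_multicast, i.e. the
  // tcgen05 MMA completion mechanism presumably arrives on the EMPTY barrier on behalf of the
  // consumer; with a 2-SM MMA atom a single arrive covers the CTA pair, and the multicast mask
  // computed in init_masks() selects the peer CTAs that must observe it.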
};

}
@@ -39,6 +39,7 @@

#include "../device/fmha.hpp"
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"

@@ -50,35 +51,74 @@ namespace cutlass::fmha::device {
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template<class Element, class ElementAccumulator, class TileShape, class Mask>
template<
  class ProblemShape,
  class Element,
  class ElementAccumulator,
  class TileShape,
  bool IsMla,
  class Mask
>
class Sm100FmhaBwd {
private:
  template <typename T>
  constexpr static auto to_bwd_shape(T shape) {
    if constexpr (IsMla) { // remove GQA mode
      constexpr int R = decltype(rank(shape))::value;
      auto HB = get<R-1>(shape);
      auto rest = take<0,R-1>(shape);
      return append(rest, make_shape(size<0>(HB), get<1>(HB)));
    }
    else {
      return shape;
    }
  }

  template <typename T>
  constexpr static auto to_bwd_stride(T stride) {
    if constexpr (IsMla) { // remove GQA mode
      constexpr int R = decltype(rank(stride))::value;
      auto HB = get<R-1>(stride);
      auto rest = take<0,R-1>(stride);
      if constexpr (is_same_v<remove_cv_t<decltype(get<0,0>(HB))>, _0>) {
        return append(rest, make_stride(get<0,1>(HB), get<1>(HB)));
      }
      else {
        return append(rest, make_stride(get<0,0>(HB), get<1>(HB)));
      }
    }
    else {
      return stride;
    }
  }
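  // Worked example (illustrative): under IsMla a problem shape (Q, K, D, D_VO, ((H_R, H_K), B))
  // is flattened to (Q, K, D, D_VO, (H_R * H_K, B)), i.e. the grouped-query head mode is folded
  // into a single head count. to_bwd_stride picks the matching head stride for that flattened
  // mode: the inner per-head stride, unless it is _0 (as for K/V tensors shared across a group),
  // in which case the per-group stride is kept.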

public:
  /// Argument structure: User API
  struct Arguments {
    // Q K D HB
    cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
    // Q K D D_VO HB
    ProblemShape problem_shape;

    const Element* ptr_Q;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_Q;
    const Element* ptr_K;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_K;
    const Element* ptr_V;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_V;

    const Element* ptr_O;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_O;
    const ElementAccumulator* ptr_LSE;
    cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
    cute::tuple<cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_LSE;

    const Element* ptr_dO;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dO;

    Element* ptr_dQ;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dQ;
    Element* ptr_dK;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dK;
    Element* ptr_dV;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
    cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dV;

    ElementAccumulator softmax_scale;

@@ -86,15 +126,27 @@ public:
  };

  using OperationSumOdO = cutlass::fmha::device::FMHA<
    cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
    cutlass::fmha::kernel::FmhaKernelBwdSumOdO<ProblemShape, Element, ElementAccumulator>
  >;
  using OperationConvert = cutlass::fmha::device::FMHA<
    cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
    cutlass::fmha::kernel::FmhaKernelBwdConvert<ProblemShape, Element, ElementAccumulator>
  >;

  using Operation = cutlass::fmha::device::FMHA<
    cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
  using OperationNormal = cutlass::fmha::device::FMHA<
    cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<
      ProblemShape, Element, ElementAccumulator, TileShape, Mask
    >
  >;

  using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{}));
  using OperationMla = cutlass::fmha::device::FMHA<
    cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
      ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask
    >
  >;

  using Operation = std::conditional_t<IsMla, OperationMla, OperationNormal>;

  using Kernel = typename Operation::Kernel;

  struct Params {
@@ -113,15 +165,16 @@ private:
      ElementAccumulator* sum_odo = nullptr,
      ElementAccumulator* scaled_lse = nullptr) {
    using namespace cute;
    auto [Q, K, D, HB] = args.problem_size;
    auto [Q_, K, D, D_VO, HB] = args.problem_shape;
    auto [H, B] = HB;
    auto [H_R, H_K] = H;
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
    auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
    int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
    auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
    auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
    auto log2_e = log2f(expf(1.0f));
    return typename OperationSumOdO::Arguments {
      args.problem_size,
      args.problem_shape,
      args.ptr_O, args.stride_O,
      args.ptr_dO, args.stride_dO,
      sum_odo, stride_sum_OdO,
@@ -133,16 +186,17 @@ private:

  static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
    using namespace cute;
    auto [Q, K, D, HB] = args.problem_size;
    auto [Q_, K, D, D_VO, HB] = args.problem_shape;
    auto [H, B] = HB;
    auto [H_R, H_K] = H;
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
    int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
    auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
    return typename OperationConvert::Arguments {
      args.problem_size,
      args.problem_shape,
      src, stride_src_dQ,
      nullptr, stride_src_dQ,
      nullptr, stride_src_dQ,
      nullptr, args.stride_dK,
      nullptr, args.stride_dV,
      args.ptr_dQ, args.stride_dQ,
      nullptr, args.stride_dK,
      nullptr, args.stride_dV,
@@ -152,21 +206,22 @@ private:

  static typename Operation::Arguments to_bwd_arguments(
      Arguments const& args,
      ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
      ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
      ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
      ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_sum_OdO = {},
      ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_scaled_lse = {},
      ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_dQ = {}) {

    return typename Operation::Arguments{
      args.problem_size,
      { args.ptr_Q, args.stride_Q,
        args.ptr_K, args.stride_K,
        args.ptr_V, args.stride_V,
        args.ptr_dO, args.stride_dO,
        scaled_lse, stride_scaled_lse,
        sum_OdO, stride_sum_OdO,
        dQ_acc, stride_dQ,
      to_bwd_shape(args.problem_shape),
      { args.ptr_Q, to_bwd_stride(args.stride_Q),
        args.ptr_K, to_bwd_stride(args.stride_K),
        args.ptr_V, to_bwd_stride(args.stride_V),
        args.ptr_dO, to_bwd_stride(args.stride_dO),
        scaled_lse, to_bwd_stride(stride_scaled_lse),
        sum_OdO, to_bwd_stride(stride_sum_OdO),
        dQ_acc, to_bwd_stride(stride_dQ),
        args.softmax_scale },
      { args.ptr_dK, args.stride_dK,
        args.ptr_dV, args.stride_dV },
      { args.ptr_dK, to_bwd_stride(args.stride_dK),
        args.ptr_dV, to_bwd_stride(args.stride_dV) },
      args.hw_info
    };
  }
@@ -199,10 +254,10 @@ public:
  /// Gets the workspace size
  static size_t
  get_workspace_size(Arguments const& args) {
    auto [Q, K, D, HB] = args.problem_size;
    auto [H, B] = HB;
    auto [Q_, K, D, D_VO, HB] = args.problem_shape;
    auto [H, B] = product_each(HB);
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
    size_t workspace_bytes = 0;
    // OdO vector
    workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
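    // Illustrative sizing (example numbers, not from this diff): with B = 2, H = 16,
    // Q = 1023 rounded up to 1024, and ElementAccumulator = float, the OdO vector alone takes
    // 2 * 16 * 1024 * 4 bytes = 128 KiB; the remaining workspace presumably covers the scaled
    // LSE (same extent) and the float dQ accumulator (padded D per head), as initialize() below
    // suggests.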
@@ -219,10 +274,10 @@ public:
    CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
      << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));

    auto [Q, K, D, HB] = args.problem_size;
    auto [H, B] = HB;
    auto [Q_, K, D, D_VO, HB] = args.problem_shape;
    auto [H, B] = product_each(HB);
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
    ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
@@ -248,10 +303,10 @@ public:
    CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    auto [Q, K, D, HB] = args.problem_size;
    auto [H, B] = HB;
    auto [Q_, K, D, D_VO, HB] = args.problem_shape;
    auto [H, B] = product_each(HB);
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
    char* workspace_chr = reinterpret_cast<char*>(workspace);
    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
    workspace_chr += B*H*Q * sizeof(ElementAccumulator);

@@ -127,7 +127,11 @@ public:
    int waves = ceil_div(B * split_heur, sm_count);
    int k_waves = ceil_div(max_splits, split_heur);
    int split_wave_aware = ceil_div(max_splits, k_waves);
    args.split_kv = split_wave_aware;
    if (args.is_fused_reduction && split_wave_aware > 1) {
      args.split_kv = std::min(split_wave_aware, static_cast<int>(sm_count/2));
    } else {
      args.split_kv = split_wave_aware;
    }
  }
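    // Illustrative walk-through (assumed numbers; split_heur and max_splits are defined above
    // this hunk): with split_heur = 8, max_splits = 20 and sm_count = 148, k_waves =
    // ceil(20 / 8) = 3 and split_wave_aware = ceil(20 / 3) = 7, so split_kv = 7; with fused
    // reduction enabled the extra cap min(7, 148 / 2) still leaves 7, and only bites when the
    // wave-aware split would exceed half the SM count.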

  /// Determines whether the GEMM can execute the given problem.
@@ -273,11 +277,33 @@ public:
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
    auto [H, K, D, B] = params.fmha_params.problem_shape;
    auto [D_latent, D_rope] = D;

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
      auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
          " cudaMemsetAsync() returned error: "
          << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
      auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int);
      uint8_t* ws = reinterpret_cast<uint8_t*>(params.fmha_params.epilogue.ptr_lse_exchange_buff);
      result = cudaMemsetAsync(ws, 0, total_bytes, stream);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
          " cudaMemsetAsync() returned error: "
          << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }
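    // Illustrative sizing for the memsets above (example numbers): with H = 128, B = 4 and
    // ElementLSE = float, the LSE exchange workspace is 128 * 4 * (4 + 4) + 2 * 4 * 4 = 4,128
    // bytes, presumably holding per-(head, batch) counters and LSE slots plus two per-batch
    // flags for the fused split-KV reduction; the output tensor is zeroed as well, presumably
    // because the fused path accumulates partial results directly into it.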
    // Use extended launch API only for mainloops that use it
    if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
@@ -298,7 +324,7 @@ public:
      CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
    if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
      // launch reduction kernel
      dim3 const block = ReductionKernel::get_block_shape();
      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);

197
examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp
Normal file
@@ -0,0 +1,197 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/


#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"

namespace cutlass::fmha::kernel {

////////////////////////////////////////////////////////////////////////////////

// Swizzle Q tile and H tile to improve L2 cache hit rate,
// and launch the longest main loop first to keep most SMs busy.

struct CausalIndividualTileScheduler {

  static constexpr int TileQ = 16;
  static constexpr int TileH = 8;
  static constexpr int TileSize = TileQ * TileH;

  struct Params {
    dim3 grid;
    int tile_max_q;
    FastDivmod divmod_tile_col;
    FastDivmod divmod_tile_size;
    FastDivmod divmod_tile_head;
  };

  bool valid_ = true;
  Params params;

  CUTLASS_DEVICE
  CausalIndividualTileScheduler(Params const& params) : params(params) {}

  template<class ProblemSize, class ClusterShape, class TileShape>
  static Params to_underlying_arguments(
      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, TileShape const& tile_shape) {
    using namespace cute;

    dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size));
    // gridDim.x must be a multiple of TileH
    const int tile_col_count = grid.x / TileH;
    const int tile_max_q = grid.y / TileQ * TileQ;
    return Params{ grid, tile_max_q, tile_col_count, TileSize, TileH };
  }

  static dim3 get_grid_shape(Params const& params) {
    return params.grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return valid_;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    const int block_idx = blockIdx.y * gridDim.x + blockIdx.x;

    int tile_idx, tile_tail;
    params.divmod_tile_size(tile_idx, tile_tail, block_idx);

    int tile_row_idx, tile_col_idx;
    params.divmod_tile_col(tile_row_idx, tile_col_idx, tile_idx);

    int row_offset_in_tail, col_offset_in_tail;
    params.divmod_tile_head(row_offset_in_tail, col_offset_in_tail, tile_tail);

    const int row_idx = tile_row_idx * TileQ + row_offset_in_tail;
    const int col_idx = tile_col_idx * TileH + col_offset_in_tail;

    // last q tile launch first
    if (blockIdx.y >= params.tile_max_q) {
      return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z)));
    }

    return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z)));
  }

  CUTLASS_DEVICE
  CausalIndividualTileScheduler& operator++() {
    valid_ = false;
    return *this;
  }
};
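
// Illustrative mapping (example numbers, not part of this file): with TileQ = 16 and TileH = 8,
// each group of 128 consecutive linearized block indices covers a 16 x 8 tile of (Q block, head)
// pairs, so nearby CTAs work on adjacent Q blocks and heads and are more likely to reuse K/V in
// L2; the Q coordinate is then reversed (gridDim.y - 1 - row_idx) so the longest causal main
// loops are scheduled first, and tail blocks with blockIdx.y >= tile_max_q skip the swizzle.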

////////////////////////////////////////////////////////////////////////////////


////////////////////////////////////////////////////////////////////////////////

// Launch order: H Q B
struct CausalPersistentTileScheduler {

  struct Params {
    int num_blocks;
    FastDivmod divmod_h;
    FastDivmod divmod_m_block;
    FastDivmod divmod_b;

    KernelHardwareInfo hw_info;
  };

  int block_idx = 0;
  Params params;

  CUTLASS_DEVICE
  CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}

  template<class ProblemSize, class ClusterShape, class TileShape>
  static Params to_underlying_arguments(
      ProblemSize const& problem_size, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, TileShape const& tile_shape) {
    using namespace cute;
    // Get SM count if needed, otherwise use user supplied SM count
    int sm_count = hw_info.sm_count;
    if (sm_count <= 0) {
      CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
          " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    }

    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
    hw_info.sm_count = sm_count;

    int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
    int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);

    return Params {
      num_blocks,
      { size<3,0>(problem_size) }, { num_m_blocks }, { size<3,1>(problem_size) },
      hw_info
    };
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
    return grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return block_idx < params.num_blocks;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    int block_decode = block_idx;
    int m_block, bidb, bidh;
    params.divmod_h(block_decode, bidh, block_decode);
    params.divmod_m_block(block_decode, m_block, block_decode);
    params.divmod_b(block_decode, bidb, block_decode);
    return make_coord(m_block, _0{}, make_coord(bidh, bidb));
  }

  CUTLASS_DEVICE
  CausalPersistentTileScheduler& operator++() {
    block_idx += gridDim.x;
    return *this;
  }
};
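
// Illustrative kernel-side loop (a sketch with assumed names, not part of this file):
//   CausalPersistentTileScheduler scheduler{params.tile_scheduler};
//   for (; scheduler.is_valid(); ++scheduler) {
//     auto [m_block, n0, bidhb] = scheduler.get_block_coord();
//     // process one (Q block, head, batch) work unit; block_idx advances by gridDim.x
//   }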
////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::kernel