Compare commits
39 Commits
| SHA1 |
|---|
| 57e3cfb47a |
| e7e0adddac |
| 6a35b4d22f |
| 56f0718a97 |
| 76c96b0be3 |
| d98e7bf7ce |
| b6ccf34aef |
| 2288c0c901 |
| b2dd65dc86 |
| 496654bf2c |
| 9ca7e877b2 |
| a49a78ffef |
| 11cad1f67b |
| 931359cec1 |
| 42e7c546c4 |
| ec18e8043b |
| 5b76420d6a |
| 19772cd63e |
| 052afcd314 |
| 86cf63e2d4 |
| a267d47f9b |
| 9e6ab77d27 |
| d0eada85a3 |
| 23139309e9 |
| 6dd13d4278 |
| 3b054767b3 |
| 6fb5e667c1 |
| 6c891db9f6 |
| da47886e34 |
| 26b7450023 |
| a39cf6b511 |
| f09045d660 |
| 84a27b3926 |
| e093b4f691 |
| 664c4f7b3e |
| 0e026982ce |
| 9a9a579714 |
| 51d730b8be |
| 6c0c8b7484 |
.github/ISSUE_TEMPLATE/bug_report.md (vendored): 23 changed lines
@@ -1,23 +0,0 @@
---
name: Bug report
about: Create a bug report to help us improve CUTLASS
title: "[BUG]"
labels: "? - Needs Triage, bug"
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**Steps/Code to reproduce bug**
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Environment details (please complete the following information):**
 - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

**Additional context**
Add any other context about the problem here.
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, new file): 38 changed lines
@@ -0,0 +1,38 @@
name: Bug Report
description: Create a bug report to help us improve CUTLASS
title: "[BUG] "
labels: ["? - Needs Triage", "bug"]
assignees: []

body:
  - type: dropdown
    id: component
    attributes:
      label: Which component has the problem?
      options:
        - CuTe DSL
        - CUTLASS C++
    validations:
      required: true
  - type: textarea
    id: bug-report
    attributes:
      label: Bug Report
      description: Please fill out all sections below
      value: |
        **Describe the bug**
        A clear and concise description of what the bug is.

        **Steps/Code to reproduce bug**
        Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.

        **Expected behavior**
        A clear and concise description of what you expected to happen.

        **Environment details (please complete the following information):**
        - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

        **Additional context**
        Add any other context about the problem here.
    validations:
      required: true
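A quick, optional way to sanity-check the issue form above before pushing is to parse it with PyYAML and confirm the dropdown options match the component names that the auto-label workflow added later in this PR expects. This sketch is not part of the PR; it only assumes PyYAML is installed and that the file is read from its repository path.

```python
import yaml  # assumes PyYAML is available

with open(".github/ISSUE_TEMPLATE/bug_report.yml") as f:
    form = yaml.safe_load(f)

# The form body is a list of blocks; pick the component dropdown defined above.
dropdown = next(block for block in form["body"] if block["type"] == "dropdown")
assert dropdown["attributes"]["options"] == ["CuTe DSL", "CUTLASS C++"]
assert dropdown["validations"]["required"] is True
```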
.github/ISSUE_TEMPLATE/feature_request.md (vendored): 20 changed lines
@@ -1,20 +0,0 @@
---
name: Feature request
about: Suggest an idea for CUTLASS
title: "[FEA]"
labels: "? - Needs Triage, feature request"
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context, code examples, or references to existing implementations about the feature request here.
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, new file): 35 changed lines
@@ -0,0 +1,35 @@
name: Feature Request
description: Suggest an idea for CUTLASS
title: "[FEA] "
labels: ["? - Needs Triage", "feature request"]
assignees: []

body:
  - type: dropdown
    id: component
    attributes:
      label: Which component requires the feature?
      options:
        - CuTe DSL
        - CUTLASS C++
    validations:
      required: true
  - type: textarea
    id: feature-request
    attributes:
      label: Feature Request
      description: Please fill out all sections below
      value: |
        **Is your feature request related to a problem? Please describe.**
        A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]

        **Describe the solution you'd like**
        A clear and concise description of what you want to happen.

        **Describe alternatives you've considered**
        A clear and concise description of any alternative solutions or features you've considered.

        **Additional context**
        Add any other context, code examples, or references to existing implementations about the feature request here.
    validations:
      required: true
.github/workflows/auto-label-issues.yml (vendored, new file): 51 changed lines
@@ -0,0 +1,51 @@
name: Auto Label Issues

on:
  issues:
    types: [opened]

jobs:
  add-labels:
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Add component label
        uses: actions/github-script@v7
        with:
          script: |
            const issue = context.payload.issue;
            const body = issue.body || '';

            // Parse the issue body to find the component selection
            // GitHub renders dropdown selections as "### {label}\n\n{selection}"
            // Check for both bug report and feature request dropdown labels
            const bugComponentMatch = body.match(/### Which component has the problem\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);
            const featureComponentMatch = body.match(/### Which component requires the feature\?\s*\n\s*\n\s*(.+?)(?:\n|$)/);

            const componentMatch = bugComponentMatch || featureComponentMatch;

            if (componentMatch) {
              const component = componentMatch[1].trim();
              let label = '';

              // Map component selections to labels
              switch(component) {
                case 'CuTe DSL':
                  label = 'CuTe DSL';
                  break;
                case 'CUTLASS C++':
                  label = 'CUTLASS C++';
                  break;
              }

              if (label) {
                await github.rest.issues.addLabels({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: issue.number,
                  labels: [label]
                });
                console.log(`Added label: ${label}`);
              }
            }
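The labeling logic above hinges on how GitHub renders a dropdown answer into the issue body, as the workflow's own comments note: `### {label}\n\n{selection}`. The following small sketch, which is not part of this PR, reproduces the same parsing rule in Python so it can be exercised outside of Actions; the prompt and label strings are the ones defined in the issue forms above.

```python
import re

# The dropdown prompts defined in bug_report.yml and feature_request.yml above.
COMPONENT_PROMPTS = (
    r"### Which component has the problem\?",
    r"### Which component requires the feature\?",
)
KNOWN_COMPONENTS = ("CuTe DSL", "CUTLASS C++")

def component_label(issue_body: str):
    """Return the component label to apply, or None if nothing matched."""
    for prompt in COMPONENT_PROMPTS:
        match = re.search(prompt + r"\s*\n\s*\n\s*(.+?)(?:\n|$)", issue_body)
        if match:
            component = match.group(1).strip()
            if component in KNOWN_COMPONENTS:
                return component
    return None

# A rendered bug-report body with the C++ component selected.
example = "### Which component has the problem?\n\nCUTLASS C++\n\n### Bug Report\n..."
assert component_label(example) == "CUTLASS C++"
```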
.github/workflows/blossom-ci.yml (vendored): 2 changed lines
@@ -55,7 +55,7 @@ jobs:
    if: |
      (startsWith(github.event.comment.body, '/bot run') ||
      startsWith(github.event.comment.body, '/bot kill')) && contains(
      fromJson('["zekunf-nv"]'),
      fromJson('["nv-fastkernels-cicd", "zekunf-nv", "hwu36", "IonThruster", "thakkarV", "d-k-b", "mihir-awatramani", "fengxie", "vickiw973", "Junkai-Wu", "brandon-yujie-sun", "lijingticy22", "hongw-nv", "vikgupta-nv", "IwakuraRein", "depaulmillz", "jackkosaian", "itramble", "ccecka", "sxtyzhangzk", "hbarclay", "yzhaiustc", "x86vk", "sklevtsov-nvidia", "ANIKET-SHIVAM", "Shreya-gaur", "azhurkevich", "serifyesil", "richardmcai", "lsyyy666", "Ethan-Yan27", "XiaoSong9905", "shdetect", "keithzzzzz"]'),
      github.actor)
    steps:
      - name: Check if comment is issued by authorized person
CHANGELOG.md: 95 changed lines
@@ -2,6 +2,97 @@

# CUTLASS 4.x

## [4.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0) (2025-09-15)

### CuTe DSL
* More Python versions are now supported for both x86-64 and aarch64, including
  - Python 3.10, 3.11, 3.12, and 3.13
* Added a new example and an updated notebook to help get started with CuTe DSL
  - [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py)
  - Updates to the [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb)
    + Added a section introducing broadcasting
* API updates
  - Please refer to the [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixes and improvements (see the short sketch following this list)
  - Fixed `cute.print_tensor` for coordinate tensors
  - Fixed `cute.print` for tuples of layouts
  - Fixed an issue where a frozen object was not properly updated after being fully assigned inside dynamic control flow
  - Fixed an issue where assigning a tuple/list element inside dynamic control flow could cause a compilation failure
  - Improved the error message when the CUDA context is not initialized
  - Improved the docstrings of `congruent` and `weakly_congruent`

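As a quick illustration of the printing utilities named in the bug-fix list, here is a minimal sketch, not taken from this PR, that wraps a PyTorch tensor and prints its layout and contents from a JIT-traced function. It assumes the usual CuTe DSL entry points (`@cute.jit`, `from_dlpack`, one-argument `cute.print` / `cute.print_tensor`, and a `.layout` attribute on `cute.Tensor`); exact signatures may differ between releases.

```python
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def inspect(t: cute.Tensor):
    # Print the tensor's layout (shape:stride), then its contents.
    cute.print(t.layout)
    cute.print_tensor(t)

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
inspect(from_dlpack(x))
```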
### CUTLASS C++
* Support for Blackwell SM103 kernels for B300 GPUs.
  - Collective mainloop code: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp)
  - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
  - Kernel code: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp).
* A set of examples that demonstrate the usage of the 3.x API for targeting the Blackwell SM103 architecture:
  - [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/).
  - [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm).
* A set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM:
  - Unit test files prefixed with `sm103_` under the [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/).
* Support for Blackwell SM121 kernels for DGX Spark GPUs.
  - These share most of their code with the Blackwell SM120 kernels.
* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario.
  - For details, please refer to the [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md).
* Further enhance the Blackwell SM100 attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
  - Add fused reduction kernel support for CUTLASS MLA.
  - Add softmax skip correction.
  - Support for GQA in the FMHA backward kernel.
  - Fix an issue where `get_unmasked_trip_count` may return a negative value.
  - Fix an issue where mbarriers are initialized with a zero arrival count.
  - Fix a corner-case issue where the sequence length of q is not a multiple of tile_q.
  - Remove TMA padding for forward kernel inputs.
* Add Blackwell SM100 kernels for MoEs (focusing on low-latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allows only one problem dimension to vary across groups/experts, unlike general grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on the API is welcome.
* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell:
  - On Blackwell SM120, a blockwise GEMM kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/).
  - On Hopper, add K-major scale factor support for SM90 blockwise kernels.
  - On Hopper, relax the restriction that the k dimension of the problem size has to be a multiple of the k dimension of the tile size.
  - On Hopper, the grouped version supports the case when k = 0.
* Support for Blackwell SM100 fp4 gemv kernels.
  - Kernel code: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h).
  - Example code: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/)
* Support for Blackwell SM100 legacy mixed input GEMM kernels.
  - Collective mainloop code: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp).
  - Kernel code: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp).
  - Example code: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/).
* Support for the Blackwell SM100 cpasync kernel.
  - Collective mainloop code: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp).
  - Kernel code: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp).
* Support Blackwell SM120 mixed input blockscaled grouped GEMM.
* Instantiate more Blackwell kernels in the profiler.
  - Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations.
  - To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. The profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels.
  - For details, please check the [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md).
* Fix some profiler issues:
  - Modify default cluster callback values to non-zero to avoid profiler failures when these values are not set on the command line.
  - Fix some no-output and timeout issues.
  - Fix Pingpong Blockwise Hopper library generation.
* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110.
  - For CUDA toolkit versions < 13.0, SM101 is still used for Thor GPUs.
  - For CUDA toolkit versions >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid.
* Rename the legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to the legacy Python interface.
  - Restructured the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`.
  - Added a Blackwell SM100 EVT emitter on the Python side and routed most emission through the Hopper SM90 emitter.
  - Added some support for running SM100 kernels via the Python interface.
* CuTe changes:
  - Fix an inaccurate GridDim calculation in the [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/).
  - Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support.
  - Fix the smallest MMA-N allowed for Blackwell fp8 and fp16 GEMM kernels.
  - Support fp16 accumulator for sm89 fp8 MMA.
  - Shorten the `nullspace` implementation.
  - Isolate and comment on `cosize` hacks.
  - Important documentation correction: `E<0,1> == 1@0@1`.
* Fix some kernel issues:
  - Fix the Hopper SM90 group GEMM kernel to only use the commit group and wait group instead of also waiting on mbarriers.
  - Fix a tiny bug when K is large for the Blackwell SM103 fp4 grouped GEMM kernel.
* Add the following unit tests:
  - [fp16 accumulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu)
  - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu)
  - [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu)
* Various improvements and fixes from the community and the CUTLASS team. Thanks to everyone who submitted PRs!
* Optimal code generation with CUDA toolkit version 13.0U1.

## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)

### CuTe DSL
@@ -10,7 +101,7 @@
  - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
  - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
* API updates
  - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details
  - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details

### CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
@@ -58,7 +149,7 @@
  - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
* API updates
  - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details
  - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details

### CUTLASS C++
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9

CMakeLists.txt

@@ -175,13 +175,25 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
endif()

if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a 121 121a)
  if (CUDA_VERSION VERSION_LESS 13.0)
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
  else()
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
  endif()
endif()

if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f 121f 103a 103f)
  if (CUDA_VERSION VERSION_LESS 13.0)
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
  else()
    list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110f)
  endif()
endif()

if (CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a)
endif()

set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
@@ -288,17 +300,49 @@ if (KERNEL_FILTER_FILE)
  set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE)
endif()

if (CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
  get_filename_component(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" ABSOLUTE)
  set(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" CACHE STRING "HEURISTICS FILE FULL PATH" FORCE)
endif()

set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list")

if(KERNEL_FILTER_FILE)
  message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}")
endif()

if(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
  message(STATUS "Full path of heuristics problems file: ${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}")
  if(DEFINED CUTLASS_NVMMH_URL)
    message(STATUS "CUTLASS_NVVMH_URL is set. Fetching dependency")
    include(FetchContent)
    FetchContent_Declare(
      nvmmh
      URL ${CUTLASS_NVMMH_URL}
    )
    FetchContent_MakeAvailable(nvmmh)
    FetchContent_GetProperties(nvmmh SOURCE_DIR nvmmh_dir)
    set(CUTLASS_NVMMH_PATH "${nvmmh_dir}")
  endif()

  if(DEFINED CUTLASS_NVMMH_PATH)
    message(STATUS "CUTLASS_NVMMH_PATH is set. Using package at: ${CUTLASS_NVMMH_PATH}")

    set(CUTLASS_NVMMH_PY_DIR "${CUTLASS_NVMMH_PATH}/python/")
    set(ENV{CUTLASS_NVMMH_SO_PATH} "${CUTLASS_NVMMH_PATH}/lib/libnvMatmulHeuristics.so")
  endif()
endif()

set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.")
set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")
set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 and SM100 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.")

if(CUTLASS_LIBRARY_INSTANTIATION_LEVEL OR CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE)
  message(STATUS "Enable extended SM90 WGMMA instruction shapes for instantiation levels")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
endif()

################################################################################
@@ -350,6 +394,10 @@ if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()

if (CUTLASS_NVCC_ARCHS MATCHES 110f)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
endif()

set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")

#

@@ -428,8 +476,6 @@ endif()
#
###################################################################################################


# Warnings-as-error exceptions and warning suppressions for Clang builds
if (CUTLASS_CLANG_HOST_COMPILE)

@@ -704,9 +750,16 @@ target_include_directories(
  CUTLASS
  SYSTEM INTERFACE
  $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
  $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
)

if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
  target_include_directories(
    CUTLASS
    SYSTEM INTERFACE
    $<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
  )
endif()

install(
  DIRECTORY
  ${CUTLASS_INCLUDE_DIR}/
@@ -1,30 +0,0 @@
# Changelog for CuTe DSL API changes

## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16)

* for loop
  - Python built-in ``range`` now always generates IR and executes at runtime
  - ``cutlass.range`` is an advanced ``range`` with IR-level unrolling and pipelining control
  - Deprecated ``cutlass.range_dynamic``; please replace it with ``range`` or ``cutlass.range``
  - **Experimental** Added ``pipelining`` control for compiler-generated software pipeline code
* while/if
  - ``while``/``if`` now by default generate IR and execute at runtime unless ``cutlass.const_expr`` is specified for the predicate
  - Deprecated ``cutlass.dynamic_expr``; please remove it
* Rename mbarrier functions to reduce ambiguity
* Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier`
* Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional.
* Introduce the `cutlass.cute.arch.get_dyn_smem_size` API to get the runtime dynamic shared memory size.
* Various API support for SM100 BlockScaled GEMM
  - Introduce BlockScaled MmaOps in [tcgen05/mma.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py), and provide a `make_blockscaled_trivial_tiled_mma` function in [blackwell_helpers.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blackwell_helpers.py) to help construct a BlockScaled TiledMma.
  - Introduce S2T CopyOps in [tcgen05/copy.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py).
  - Introduce BlockScaled layout utilities in [blockscaled_layout.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blockscaled_layout.py) for creating the required scale factor layouts in global memory, shared memory and tensor memory.
* `cutlass.cute.compile` now supports compilation options. Refer to [JIT compilation options](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.html) for more details.
* `cutlass.cute.testing.assert_` now works for device JIT functions. Specify `--enable-device-assertions` as a compilation option to enable it.
* `cutlass.cute.make_tiled_copy` is now deprecated. Please use `cutlass.cute.make_tiled_copy_tv` instead.
* Shared memory capacity query
  - Introduce `cutlass.utils.get_smem_capacity_in_bytes` for querying the shared memory capacity.
  - `<arch>_utils.SMEM_CAPACITY["<arch_str>"]` is now deprecated.

## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)

* Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
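The control-flow semantics described in the removed DSL changelog above can be sketched roughly as follows. This is an illustration, not code from this PR: it assumes the `@cute.jit` decorator, the `cutlass.Int32` / `cutlass.Constexpr` annotations, and the `cutlass.range` / `cutlass.const_expr` helpers named above; the `unroll` keyword on `cutlass.range` is an assumed spelling of the unrolling control.

```python
import cutlass
import cutlass.cute as cute

@cute.jit
def control_flow_demo(n: cutlass.Int32, use_extra: cutlass.Constexpr):
    acc = cutlass.Int32(0)

    # Built-in range now always generates IR and executes at runtime,
    # so `n` may be a dynamic value.
    for i in range(n):
        acc += i

    # cutlass.range is the advanced form with IR-level unrolling and
    # pipelining control (the `unroll` keyword here is an assumption).
    for _ in cutlass.range(8, unroll=2):
        acc += 1

    # cutlass.const_expr marks the predicate as a trace-time constant, so this
    # branch is resolved during compilation rather than lowered to runtime IR.
    if cutlass.const_expr(use_extra):
        acc += 100

# Ahead-of-time compilation, to which options can also be passed (see the
# JIT compilation options page referenced above):
# compiled = cute.compile(control_flow_demo, 16, True)
```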
README.md: 120 changed lines
@@ -1,9 +1,9 @@

# Overview

# CUTLASS 4.1.0
# CUTLASS 4.2.0

_CUTLASS 4.1.0 - July 2025_
_CUTLASS 4.2.0 - Sept 2025_

CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@@ -27,14 +27,14 @@ native support of such data types) across NVIDIA's Volta, Turing, Ampere, Ada, H

To this rich ecosystem of C++ based kernel programming abstractions, CUTLASS 4 adds CUTLASS DSLs. These are Python native interfaces for writing high-performance CUDA kernels based on core CUTLASS and CuTe concepts without any performance compromises. This allows for a much smoother learning curve, orders of magnitude faster compile times, native integration with DL frameworks without writing glue code, and much more intuitive metaprogramming that does not require deep C++ expertise.

Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions — exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy.
Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions -- exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy.

CuTe DSL demonstrates optimal matrix multiply and other linear algebra operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Ampere, Hopper, and Blackwell architectures.

We believe it will become an indispensable tool for students, researchers, and performance
engineers alike — flattening the learning curve of GPU programming, rapidly prototyping kernel
engineers alike -- flattening the learning curve of GPU programming, rapidly prototyping kernel
designs, and bringing optimized solutions into production.

CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
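To make the "core concepts such as layouts, tensors, hardware atoms" mentioned above concrete, here is a tiny, hedged sketch of CuTe layouts from Python. It is not part of this PR and assumes `cute.make_layout` and one-argument `cute.print` are available as in current CuTe DSL releases; signatures may differ.

```python
import cutlass.cute as cute

@cute.jit
def describe_layouts():
    # A (4, 8) tile stored row-major: stride (8, 1) maps coordinate (i, j) to i*8 + j.
    row_major = cute.make_layout((4, 8), stride=(8, 1))
    # The same shape stored column-major: stride (1, 4) maps (i, j) to i + j*4.
    col_major = cute.make_layout((4, 8), stride=(1, 4))
    # Print the pair of layouts (a tuple of layouts, as in the changelog fix).
    cute.print((row_major, col_major))

describe_layouts()
```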
@@ -43,40 +43,94 @@ To get started quickly - please refer :
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).

# What's New in CUTLASS 4.1
# What's New in CUTLASS 4.2

## CuTe DSL
* Add aarch64 support, you can now pip install `nvidia-cutlass-dsl` on GB200 systems!
* More examples demonstrating how to use CuTe DSL to write peak-performance kernels
  - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py)
  - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py)
* More Python versions are now supported for both x86-64 and aarch64, including
  - Python 3.10, 3.11, 3.12, and 3.13
* Added a new example and an updated notebook to help get started with CuTe DSL
  - [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py)
  - Updates to the [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb)
    + Added a section introducing broadcasting
* API updates
  - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details
  - Please refer to the [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details
* Bug fixes and improvements
  - Fixed `cute.print_tensor` for coordinate tensors
  - Fixed `cute.print` for tuples of layouts
  - Fixed an issue where a frozen object was not properly updated after being fully assigned inside dynamic control flow
  - Fixed an issue where assigning a tuple/list element inside dynamic control flow could cause a compilation failure
  - Improved the error message when the CUDA context is not initialized
  - Improved the docstrings of `congruent` and `weakly_congruent`

## CUTLASS C++
|
||||
* Support for Blackwell SM103 kernels for B300 GPUs.
|
||||
- Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp)
|
||||
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture:
|
||||
- [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/).
|
||||
- [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm).
|
||||
* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM
|
||||
- Unit test files with prefix name of `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/).
|
||||
* Support for Blackwell SM121 kernels for DGX Spark GPUs.
|
||||
- Share the major codes with Blackwell SM120 kernels.
|
||||
* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario.
|
||||
- Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md).
|
||||
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
- Add variable sequence length support for FMHA Backward kernel.
|
||||
- Add varlen test support to Backward runner.
|
||||
- Codes support empty batch sequences.
|
||||
* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays.
|
||||
- Add fused reduction kernel support for cutlass MLA.
|
||||
- Add softmax skip correction.
|
||||
- Support for GQA in FMHA backward kernel.
|
||||
- Fix an issue where `get_unmasked_trip_count` may return a negative value.
|
||||
- Fix an issue where mbarriers are initialized with a zero arrival count.
|
||||
- Fix a corner case issue where the sequence length of q is not a multiple of tile_q.
|
||||
- Remove tma padding for forward kernel inputs.
|
||||
* Add Blackwell SM100 kernels for MoEs (focusing on Low-Latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allow only one problem dimension to vary across groups/experts, unlike general Grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on API is welcome.
|
||||
* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell
|
||||
- On Blackwell SM120, a blockwise gemm kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/).
|
||||
- On Hopper, add K major scale factor support for SM90 blockwise kernels.
|
||||
- On Hopper, relax the restriction that the k dimension of the problem size has to be the multiple of the k dimension of the tile size.
|
||||
- On Hopper, grouped version supports the case when k = 0.
|
||||
* Support for Blackwell SM100 fp4 gemv kernels.
|
||||
- Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h).
|
||||
- Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/)
|
||||
* Support for Blackwell SM100 legacy mixed input GEMM kernels.
|
||||
- Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp).
|
||||
- Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp).
|
||||
- Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/).
|
||||
* Support for Blackwell SM100 cpasync kernel.
|
||||
- Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp).
|
||||
- Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp).
|
||||
* Support Blackwell SM120 mixed input blockscaled grouped GEMM.
|
||||
* Instantiating more Blackwell kernels in profiler.
|
||||
- Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations.
|
||||
- To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. Profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels.
|
||||
- Details please check [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md).
|
||||
* Fix some profiler issues:
|
||||
- Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line.
|
||||
- Fix some no output and timeout issues.
|
||||
- Fix Pingpong Blockwise Hopper library generation.
|
||||
* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110.
|
||||
- For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs.
|
||||
- For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid.
|
||||
* Rename legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to legacy Python interface.
|
||||
- Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`.
|
||||
- Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter.
|
||||
- Added some support for running SM100 kernels via the Python interface.
|
||||
* CuTe changes:
|
||||
- Rewrite ArithTuple and ScaledBasis for robustness and clarity.
|
||||
- Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX.
|
||||
- Factor out `print_latex` and friends and rewrite.
|
||||
- Factor out `print_svg` and friends and rewrite.
|
||||
* Support Blackwell SM100 SIMT packed fp32x2 kernels.
|
||||
* Support residual add for implicit gemm kernels.
|
||||
* Various fixes for CUTLASS C++ Python interface's EVT tracer:
|
||||
- Add verifier for sm90 to report the invalid input.
|
||||
- When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges.
|
||||
- Register operations of tanh, sigmoid, exp, gelu to the python ast frontend.
|
||||
- Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback.
|
||||
* Fix profiler bugs in exhaustive perf search.
|
||||
- Fix incorrect cluster shape output issue when doing exhaustive search.
|
||||
- Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders.
|
||||
* Fix some profiler issues.
|
||||
- Complete the reference for Blackwell blockwise gemm kernels.
|
||||
- Fix incorrect regex logic for L1 test.
|
||||
- Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/).
|
||||
- Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support.
|
||||
- Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels.
|
||||
- Support fp16 accmulator for sm89 fp8 mma.
|
||||
- Shorten `nullspace` implementation.
|
||||
- Isolate and comment on `cosize` hacks.
|
||||
- Important documentation correction: `E<0,1> == 1@0@1`.
|
||||
* Fix some kernel issues:
|
||||
- Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers.
|
||||
- Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel.
|
||||
* Add following unit tests:
|
||||
- [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu)
|
||||
- [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu)
|
||||
- [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu)
|
||||
|
||||
Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
@ -170,7 +224,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|
||||
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|
||||
|NVIDIA GeForce RTX 50x0 series |10.0|12.8|
|
||||
|NVIDIA GeForce RTX 50x0 series |12.0|12.8|
|
||||
|
||||
## Target Architecture
|
||||
|
||||
@ -202,7 +256,7 @@ cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
|
||||
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
|
||||
products has a different compute capability than the one underpinning
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs (SM120). As a result, kernels
|
||||
compiled for Blackwell SM100 architecture with arch conditional features
|
||||
(using `sm100a`) are not compatible with RTX 50 series GPUs.
|
||||
|
||||
|
||||
@ -65,7 +65,12 @@ endfunction()
|
||||
|
||||
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
|
||||
|
||||
set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f)
|
||||
set(PROFILER_ARCH_LIST 100a 100f 103a 120a 120f 121a)
|
||||
if (CUDA_VERSION VERSION_LESS 13.0)
|
||||
list(APPEND PROFILER_ARCH_LIST 101a 101f)
|
||||
else()
|
||||
list(APPEND PROFILER_ARCH_LIST 110a 110f)
|
||||
endif()
|
||||
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
|
||||
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
|
||||
message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests")
|
||||
|
||||
@ -45,7 +45,7 @@
|
||||
cutlass::half_t
|
||||
|
||||
This is a numeric type implementing IEEE half-precision quantities. It is functional in host
|
||||
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables harware-accelerated
|
||||
and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables hardware-accelerated
|
||||
numeric conversion on x86-64 CPUs support F16C extensions. In device code, all available
|
||||
hardware is used to implement conversion and numeric operations.
|
||||
|
||||
|
||||
@ -243,10 +243,11 @@ cudaError_t run_batched_gemm(bool use_array) {
|
||||
const char* gemm_desc = use_array ? "array" : "strided batched";
|
||||
std::cout << "Running " << gemm_desc << " gemm" << std::endl;
|
||||
|
||||
// Arbitrary problem size
|
||||
// Arbitrary matrix shape
|
||||
int const m = 520;
|
||||
int const n = 219;
|
||||
int const k = 129;
|
||||
|
||||
int const batch_count = 17;
|
||||
|
||||
// A, B are non-transpose, column major
|
||||
|
||||
@ -659,7 +659,7 @@ struct Testbed {
|
||||
}
|
||||
|
||||
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
|
||||
int64_t bytes = cutlass::bits_to_bytes(
|
||||
int64_t bytes = cutlass::bits_to_bytes<int64_t>(
|
||||
(cutlass::sizeof_bits<ElementD>::value * 2 + cutlass::sizeof_bits<ElementSoftmax>::value) *
|
||||
options.problem_size.m() * options.problem_size.n());
|
||||
|
||||
|
||||
@ -33,8 +33,8 @@
|
||||
computing reference permutations of 4/5D tensors when source data is column-major.
|
||||
*/
|
||||
#pragma once
|
||||
#include <cuda/std/cassert>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/coord.h"
|
||||
|
||||
@ -40,14 +40,12 @@
|
||||
Note that in general the fragment passed to the OutputOp could
|
||||
span multiple rows but it does not happen with the configurations we have
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
|
||||
@ -42,12 +42,10 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
|
||||
@ -38,10 +38,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
|
||||
@ -37,12 +37,10 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda/std/cassert>
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include CUDA_STD_HEADER(cassert)
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
@ -26,7 +26,9 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 90a)
|
||||
cutlass_example_add_executable(
|
||||
65_distributed_gemm
|
||||
65_distributed_gemm.cu
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -129,7 +129,7 @@ using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_confi
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
@ -132,12 +132,12 @@ constexpr int ScaleGranularityK = 128;
|
||||
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
|
||||
@ -145,7 +145,7 @@ using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularity
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
|
||||
@ -402,12 +402,37 @@ void initialize(const OptionType &options) {
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
// If the current group's matrix has size 0, set the pointer to nullptr
|
||||
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
|
||||
ptr_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
|
||||
ptr_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
|
||||
ptr_C_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
|
||||
ptr_D_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
|
||||
ptr_blockscale_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
|
||||
ptr_blockscale_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
}
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
@ -546,10 +571,10 @@ bool verify(const OptionType &options) {
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
@ -598,11 +623,7 @@ bool verify(const OptionType &options) {
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // Aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
@ -639,6 +660,24 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::string raster = "Heuristic";
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
@ -671,8 +710,7 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
@ -686,25 +724,6 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
|
||||
@ -132,8 +132,7 @@ using ElementCompute = float; // E
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
static constexpr int ScaleGranularityM = 1;
|
||||
@ -142,13 +141,13 @@ static constexpr int ScaleGranularityK = 128;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::GMMA::Major::MN, cute::GMMA::Major::K>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
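The printed ScaleMsPerTile/ScaleNsPerTile values follow directly from the tile shape and the scale granularities in this hunk. A minimal sketch of that arithmetic, assuming the ScaleGranularityM = 1 and the 128x128x128 tile shown here plus a ScaleGranularityN of 128 (its value is not visible in this hunk):

```cpp
// Illustrative sketch only: how the per-tile scale counts fall out of the 128x128x128
// tile and the scale granularities. ScaleGranularityN = 128 is an assumption.
#include <cstdio>

int main() {
  constexpr int tile_m = 128, tile_n = 128;       // size<0>(TileShape{}), size<1>(TileShape{})
  constexpr int scale_granularity_m = 1;          // one scale factor per row of the tile
  constexpr int scale_granularity_n = 128;        // assumed: one scale factor per tile width
  constexpr int scale_ms_per_tile = tile_m / scale_granularity_m;  // 128
  constexpr int scale_ns_per_tile = tile_n / scale_granularity_n;  // 1
  std::printf("ScaleMsPerTile=%d ScaleNsPerTile=%d\n", scale_ms_per_tile, scale_ns_per_tile);
  return 0;
}
```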
|
||||
@ -407,12 +406,37 @@ void initialize(const OptionType &options) {
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
// If the current group's matrix has size 0, set the pointer to nullptr
|
||||
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
|
||||
ptr_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
|
||||
ptr_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
|
||||
ptr_C_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
|
||||
ptr_D_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_A.at(i) == offset_blockscale_A.at(i + 1)) {
|
||||
ptr_blockscale_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_blockscale_B.at(i) == offset_blockscale_B.at(i + 1)) {
|
||||
ptr_blockscale_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
}
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
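The pointer assignments added in this hunk repeat one pattern for A, B, C, D and both block-scale tensors: a group whose matrix occupies zero bytes is detected by two consecutive equal offsets and gets a nullptr instead of an aliased address. A condensed sketch of that pattern (helper name and signature are mine, not the example's):

```cpp
// Illustrative helper: publish nullptr for size-0 groups, a valid pointer otherwise.
#include <cstdint>
#include <vector>

template <class T>
void assign_group_ptrs(std::vector<T*>& ptrs, T* base,
                       std::vector<int64_t> const& offsets, int groups) {
  for (int i = 0; i < groups; ++i) {
    bool empty = (i < groups - 1) && (offsets[i] == offsets[i + 1]);  // next group starts at the same offset
    ptrs[i] = empty ? nullptr : base + offsets[i];
  }
}
```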
|
||||
@ -551,10 +575,10 @@ bool verify(const OptionType &options) {
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
std::cout << " Running host reference kernel - may run for a while for large problems." << std::endl;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
@ -637,10 +661,27 @@ bool verify(const OptionType &options) {
|
||||
template <typename OptionType>
|
||||
int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::string raster = "Heuristic";
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
@ -695,27 +736,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
ScaleMsPerTile,
|
||||
ScaleNsPerTile>(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
std::cout << " GBPS: " << result.gbps << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
return 0;
|
||||
@ -766,8 +790,8 @@ int main(int argc, char const **args) {
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running tests with host problem shapes:" << std::endl;
|
||||
run(options, true);
|
||||
|
||||
std::cout << "Running tests without host problem shapes:" << std::endl;
|
||||
run(options, false);
|
||||
|
||||
|
||||
@ -44,6 +44,9 @@ set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0)
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_K_16B_ALIGNED --m=256 --n=512 --k=960 --groups=10 --iterations=0)
|
||||
set(TEST_K_16B_ALIGNED_LARGE_GROUP --m=256 --n=512 --k=960 --groups=512 --iterations=0)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
|
||||
@ -58,6 +61,8 @@ cutlass_example_add_executable(
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_K_16B_ALIGNED
|
||||
TEST_K_16B_ALIGNED_LARGE_GROUP
|
||||
)
|
||||
|
||||
# MSVC will fail to compile this example with the following error:
|
||||
|
||||
@ -111,14 +111,14 @@ struct Options {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
|
||||
if (m < 0) {
|
||||
m = m_alignment * (rand() % (64 * alignment / m_alignment));
|
||||
}
|
||||
if (n < 1) {
|
||||
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
|
||||
if (n < 0) {
|
||||
n = n_alignment * (rand() % (64 * alignment / n_alignment));
|
||||
}
|
||||
if (k < 1) {
|
||||
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
|
||||
if (k < 0) {
|
||||
k = k_alignment * (rand() % (32 * alignment / k_alignment));
|
||||
}
|
||||
problem_sizes_after_alignment_host.push_back({m, n, k});
|
||||
problem_sizes_host.push_back({m, n, k});
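A brief illustration of what the relaxed guards above change (the helper below is hypothetical): a randomly generated extent may now be exactly zero because the former "+ 1" is gone, and an explicit --m=0 / --n=0 / --k=0 is kept rather than re-randomized (the guard is now `< 0` instead of `< 1`). Zero-sized groups are what the nullptr handling in initialize() accounts for.

```cpp
// Hypothetical helper mirroring the pattern above: rand() % buckets may be 0,
// so a group can legitimately end up with a zero extent.
#include <cstdlib>

int random_extent(int alignment_unit, int buckets) {
  return alignment_unit * (std::rand() % buckets);  // multiple of the alignment, possibly 0
}
```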
|
||||
|
||||
@ -454,11 +454,12 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -640,11 +640,11 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
|
||||
@ -33,7 +33,7 @@ set(TEST_SWIZZLE_2 --swizzle=2)
|
||||
set(TEST_SWIZZLE_5 --swizzle=5)
|
||||
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp16_gemm
|
||||
70_blackwell_fp16_gemm.cu
|
||||
|
||||
@ -449,9 +449,9 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
|
||||
cutlass_example_add_executable(
|
||||
71_blackwell_gemm_with_collective_builder
|
||||
71_blackwell_gemm_with_collective_builder.cu
|
||||
|
||||
@ -116,7 +116,7 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
@ -511,10 +511,10 @@ int main(int argc, char const **args) {
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
|
||||
@ -566,8 +566,8 @@ int main(int argc, char const **args) {
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -117,7 +117,7 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
@ -512,8 +512,8 @@ int main(int argc, char const **args) {
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
|
||||
cutlass_example_add_executable(
|
||||
72a_blackwell_nvfp4_bf16_gemm
|
||||
72a_blackwell_nvfp4_bf16_gemm.cu
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
|
||||
cutlass_example_add_executable(
|
||||
73_blackwell_gemm_preferred_cluster
|
||||
blackwell_gemm_preferred_cluster.cu
|
||||
|
||||
@ -513,7 +513,7 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -29,9 +29,9 @@
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
74_blackwell_gemm_streamk
|
||||
blackwell_gemm_streamk.cu
|
||||
if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f")
|
||||
cutlass_example_add_executable(
|
||||
74_blackwell_gemm_streamk
|
||||
blackwell_gemm_streamk.cu
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -556,10 +556,19 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
if (__CUDACC_VER_MAJOR__ < 13) {
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
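Judging by its error message, the CUDA 13 branch above intends "compute capability 100 or 110". A hedged sketch of that predicate written out directly (my reading, not the example's code):

```cpp
// Sketch only: accept exactly compute capability 10.0 or 11.0. Note that the expression
// `props.major != 10 || props.major != 11` in the hunk above holds for every major
// version, so a direct formulation like this is easier to audit.
#include <cuda_runtime.h>

bool is_cc_100_or_110(cudaDeviceProp const& props) {
  return (props.major == 10 || props.major == 11) && props.minor == 0;
}
```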
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -762,9 +762,8 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -138,8 +138,7 @@ using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor
|
||||
|
||||
// Core kernel configurations
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
|
||||
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
|
||||
// Runtime Cluster Shape
|
||||
@ -159,7 +158,7 @@ struct MMA2SMConfig {
|
||||
};
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
typename MMA1SMConfig::MmaTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
@ -169,7 +168,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
@ -187,7 +186,7 @@ using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = Gemm1SM;
|
||||
|
||||
using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
typename MMA2SMConfig::MmaTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
@ -197,13 +196,13 @@ using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::Collective
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename MMA2SMConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue2SM::SharedStorage))>,
|
||||
typename MMA2SMConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal<
|
||||
@ -233,7 +232,7 @@ using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFA> layout_SFB_host;
|
||||
std::vector<LayoutSFB> layout_SFB_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
@ -897,9 +896,8 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
if(CUTLASS_NVCC_ARCHS STREQUAL "100a")
|
||||
cutlass_example_add_executable(
|
||||
75_blackwell_grouped_gemm
|
||||
75_blackwell_grouped_gemm.cu
|
||||
|
||||
@ -504,10 +504,19 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
if (__CUDACC_VER_MAJOR__ < 13) {
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -504,10 +504,19 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
if (__CUDACC_VER_MAJOR__ < 13) {
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -500,10 +500,19 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
if (__CUDACC_VER_MAJOR__ < 13) {
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if ((props.major != 10 || props.major != 11) && props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -126,6 +126,7 @@ struct Options {
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
bool causal_q_begin = true;
|
||||
bool residual = false;
|
||||
bool varlen = false;
|
||||
bool persistent = false;
|
||||
@ -266,6 +267,8 @@ struct Options {
|
||||
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
std::string causal_type;
|
||||
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
|
||||
if (mask == "no" || mask == "") {
|
||||
causal = residual = false;
|
||||
if (varlen) {
|
||||
@ -275,6 +278,11 @@ struct Options {
|
||||
else if (mask == "causal") {
|
||||
residual = false;
|
||||
causal = true;
|
||||
if(causal_type == "qend") {
|
||||
causal_q_begin = false;
|
||||
} else {
|
||||
causal_q_begin = true;
|
||||
}
|
||||
}
|
||||
else if (mask == "residual") {
|
||||
residual = true;
|
||||
@ -313,6 +321,7 @@ struct Options {
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||
<< " --causal-type=<qbegin|qend> Causal mask type\n"
|
||||
<< " --persistent Enables persistent scheduler\n"
|
||||
<< " --varlen Enables variable sequence length\n"
|
||||
<< " B*Q and B*K become the total sequence length\n"
|
||||
@ -410,16 +419,16 @@ struct FwdRunner {
|
||||
using ElementAccumulatorPV = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
// Q K D (B H)
|
||||
// Q K D ((H_R, H_K) B)
|
||||
using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
|
||||
|
||||
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D (H_G H_R B)
|
||||
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D (H_G H_R B)
|
||||
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D ((H_R, H_K), B)
|
||||
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D ((H_R, H_K), B)
|
||||
using StrideV = StrideK;
|
||||
using StrideO = StrideQ;
|
||||
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q (H_G H_R B)
|
||||
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q ((H_R, H_K), B)
|
||||
|
||||
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
|
||||
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
|
||||
@ -602,8 +611,8 @@ struct FwdRunner {
|
||||
|
||||
ProblemShapeType problem_size_for_launch;
|
||||
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
|
||||
get<2>(problem_size_for_launch) = get<2>(problem_size);
|
||||
get<3>(problem_size_for_launch) = get<3>(problem_size);
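My reading of the aggregate used in the two assignments above (the field names below are assumptions, not the library's declaration): the launch-time problem shape now carries the total sequence length next to the maximum per-batch length, with the cumulative-length device pointer still filled in later.

```cpp
// Assumed field layout, for illustration only -- it matches the brace-initializers
// {max_seqlen, nullptr, total_seqlen} used above.
struct VariableLengthSketch {
  int        max_length;      // max_seqlen_q / max_seqlen_kv
  int const* cumulative_ptr;  // per-batch cumulative lengths; nullptr at this point
  int        total_length;    // total_seqlen_q / total_seqlen_kv
};
```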
|
||||
|
||||
@ -660,9 +669,9 @@ struct FwdRunner {
|
||||
}
|
||||
|
||||
auto buffer_init_fn = [&](auto& buffer) {
|
||||
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
buffer.block_Q.reset(size(shape_QO));
|
||||
buffer.block_K.reset(size(shape_KV));
|
||||
buffer.block_V.reset(size(shape_KV));
|
||||
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_LSE.reset(size(shape_LSE));
|
||||
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
@ -1078,7 +1087,11 @@ int main_single(int argc, char const **args) {
|
||||
|
||||
auto with_mask = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask{});
|
||||
if(options.causal_q_begin) {
|
||||
fn(CausalMask{});
|
||||
} else {
|
||||
fn(CausalMask<false>{});
|
||||
}
|
||||
}
|
||||
else if (options.residual) {
|
||||
fn(ResidualMask{});
|
||||
|
||||
@ -183,6 +183,9 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
@ -298,6 +301,7 @@ struct Options {
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
|
||||
@ -405,25 +409,24 @@ struct BwdRunner {
|
||||
#endif
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Q K D (H B)
|
||||
// Q K D D_VO ((H_R, H_K) B)
|
||||
using ProblemShape = std::conditional_t<
|
||||
kIsVarlen,
|
||||
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<int, int>>,
|
||||
cute::tuple<int, int, int, int, cute::tuple<int, int>>
|
||||
cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
|
||||
cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
|
||||
>;
|
||||
|
||||
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
|
||||
using StrideQ = TensorStride;
|
||||
using StrideK = TensorStride;
|
||||
using StrideV = TensorStride;
|
||||
using StrideO = TensorStride;
|
||||
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
|
||||
using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
|
||||
using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
|
||||
using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
|
||||
using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
|
||||
using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)
|
||||
|
||||
// Backwards specific
|
||||
using StrideDQ = TensorStride;
|
||||
using StrideDK = TensorStride;
|
||||
using StrideDV = TensorStride;
|
||||
using StrideDO = TensorStride;
|
||||
using StrideDQ = StrideQ;
|
||||
using StrideDK = StrideK;
|
||||
using StrideDV = StrideV;
|
||||
using StrideDO = StrideO;
|
||||
|
||||
//
|
||||
// Data members
|
||||
@ -468,43 +471,15 @@ struct BwdRunner {
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
auto [H, B] = HB;
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,4>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,4>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,3,4>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,3,4>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
// keep going here! (this might be better in cursor)
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,4>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
|
||||
select<0,2,4>(problem_shape),
|
||||
stride_dQ);
|
||||
|
||||
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
|
||||
select<1,2,4>(problem_shape),
|
||||
stride_dK);
|
||||
|
||||
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
|
||||
select<1,3,4>(problem_shape),
|
||||
stride_dV);
|
||||
|
||||
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
|
||||
select<0,3,4>(problem_shape),
|
||||
stride_dO);
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), make_shape(Q, D, HB), stride_Q);
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), make_shape(K, D, HB), stride_K);
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), make_shape(K, D_VO, HB), stride_V);
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), make_shape(Q, D_VO, HB), stride_O);
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), make_shape(Q, HB), stride_LSE);
|
||||
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), make_shape(Q, D, HB), stride_dQ);
|
||||
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), make_shape(K, D, HB), stride_dK);
|
||||
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), make_shape(K, D_VO, HB), stride_dV);
|
||||
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), make_shape(Q, D_VO, HB), stride_dO);
|
||||
|
||||
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
|
||||
|
||||
@ -549,6 +524,9 @@ struct BwdRunner {
|
||||
}
|
||||
|
||||
auto initialize_problem_shape(Options const& options) {
|
||||
int h_r = options.h / options.h_k;
|
||||
assert(options.h % options.h_k == 0);
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
int num_batches = options.b;
|
||||
|
||||
@ -599,14 +577,14 @@ struct BwdRunner {
|
||||
ProblemShape problem_shape{
|
||||
{max_seqlen_q, block_cumulative_seqlen_q.get(), total_seqlen_q},
|
||||
{max_seqlen_kv, block_cumulative_seqlen_kv.get(), total_seqlen_kv},
|
||||
options.d, options.d_vo, {options.h, options.b}
|
||||
options.d, options.d_vo, {{h_r, options.h_k}, options.b}
|
||||
};
|
||||
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(options.h, 1));
|
||||
auto tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, options.d, options.d_vo, make_shape(make_shape(h_r, options.h_k), 1));
|
||||
|
||||
return cute::make_tuple(problem_shape, tensor_shape);
|
||||
}
|
||||
else {
|
||||
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {options.h, options.b}};
|
||||
ProblemShape problem_shape{options.q, options.k, options.d, options.d_vo, {{h_r, options.h_k}, options.b}};
|
||||
return cute::make_tuple(problem_shape, problem_shape);
|
||||
}
|
||||
}
|
||||
@ -616,22 +594,23 @@ struct BwdRunner {
|
||||
auto [problem_shape, tensor_shape] = initialize_problem_shape(options);
|
||||
auto [Q, K, D, D_VO, HB] = tensor_shape;
|
||||
auto [H, B] = HB;
|
||||
auto [H_R, H_K] = H;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
|
||||
// for varlen, Q == total_Q, K == total_K, B = 1
|
||||
// but in problem_shape, they've got to be max_Q/max_K, and B = B
|
||||
|
||||
auto shape_Q = make_shape(Q, D, make_shape(H, B));
|
||||
auto shape_O = make_shape(Q, D_VO, make_shape(H, B));
|
||||
auto shape_K = make_shape(K, D, make_shape(H, B));
|
||||
auto shape_V = make_shape(K, D_VO, make_shape(H, B));
|
||||
auto shape_LSE = make_shape(Q, make_shape(H, B));
|
||||
auto shape_Q = make_shape(Q, D, HB);
|
||||
auto shape_K = make_shape(K, D, HB);
|
||||
auto shape_V = make_shape(K, D_VO, HB);
|
||||
auto shape_O = make_shape(Q, D_VO, HB);
|
||||
auto shape_LSE = make_shape(Q, HB);
|
||||
|
||||
stride_Q = make_stride(D, _1{}, make_stride(D*Q, B == 1 ? 0 : D*Q*H));
|
||||
stride_K = make_stride(D, _1{}, make_stride(D*K, B == 1 ? 0 : D*K*H));
|
||||
stride_V = make_stride(D_VO, _1{}, make_stride(D_VO*K, B == 1 ? 0 : D_VO*K*H));
|
||||
stride_O = make_stride(D_VO, _1{}, make_stride(D_VO*Q, B == 1 ? 0 : D_VO*Q*H));
|
||||
stride_LSE = make_stride(_1{}, make_stride(Q, B == 1 ? 0 : Q*H));
|
||||
stride_Q = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
|
||||
stride_K = make_stride(D, _1{}, make_stride(make_stride(_0{}, D*K), B == 1 ? 0 : D*K*H_K));
|
||||
stride_V = make_stride(D_VO, _1{}, make_stride(make_stride(_0{},D_VO*K), B == 1 ? 0 : D_VO*K*H_K));
|
||||
stride_O = make_stride(D_VO, _1{}, make_stride(make_stride(D_VO*Q, D_VO*Q*H_R), B == 1 ? 0 : D_VO*Q*H_R*H_K));
|
||||
stride_LSE = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
|
||||
|
||||
stride_dQ = stride_Q;
|
||||
stride_dK = stride_K;
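The strides above encode GQA sharing directly in the layout: the inner H_R mode of the K/V strides is `_0{}`, so every query head within a group addresses the same K/V head. A small illustrative function (not part of the example) that linearizes a (k, d, ((h_r, h_k), b)) coordinate under the K/V stride defined above:

```cpp
// Illustrative only: the zero stride on h_r is what makes all query heads in a group
// read the same K/V head; the batch stride collapses to 0 when B == 1, as above.
#include <cstdint>

int64_t kv_offset(int k, int d, int h_r, int h_k, int b,
                  int64_t D, int64_t K, int64_t H_K, bool single_batch) {
  int64_t stride_hr = 0;                               // _0{} in the CuTe stride
  int64_t stride_hk = D * K;                           // one distinct head per h_k index
  int64_t stride_b  = single_batch ? 0 : D * K * H_K;  // batch mode
  return int64_t(k) * D + d + h_r * stride_hr + h_k * stride_hk + int64_t(b) * stride_b;
}
```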
|
||||
@ -642,20 +621,23 @@ struct BwdRunner {
|
||||
return size(make_shape(1ull, shape));
|
||||
};
|
||||
|
||||
auto size_K = lsize(K * D * H_K * B);
|
||||
auto size_V = lsize(K * D_VO * H_K * B);
|
||||
|
||||
block_Q.reset(lsize(shape_Q));
|
||||
block_K.reset(lsize(shape_K));
|
||||
block_V.reset(lsize(shape_V));
|
||||
block_K.reset(size_K);
|
||||
block_V.reset(size_V);
|
||||
block_O.reset(lsize(shape_O));
|
||||
block_LSE.reset(lsize(shape_LSE));
|
||||
|
||||
block_dQ.reset(lsize(shape_Q));
|
||||
block_dK.reset(lsize(shape_K));
|
||||
block_dV.reset(lsize(shape_V));
|
||||
block_dK.reset(size_K);
|
||||
block_dV.reset(size_V);
|
||||
block_dO.reset(lsize(shape_O));
|
||||
|
||||
block_ref_dQ.reset(lsize(shape_Q));
|
||||
block_ref_dK.reset(lsize(shape_K));
|
||||
block_ref_dV.reset(lsize(shape_V));
|
||||
block_ref_dK.reset(size_K);
|
||||
block_ref_dV.reset(size_V);
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
@ -689,7 +671,7 @@ struct BwdRunner {
|
||||
select<0,4>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
if (! options.skip_reference) {
|
||||
if (not options.skip_reference) {
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
}
|
||||
|
||||
@ -816,11 +798,12 @@ struct BwdRunner {
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask> ? 0.5 : 1.0);
|
||||
double flops = 2.0 * (std::is_same_v<ActiveMask, CausalForBackwardMask<false>> || std::is_same_v<ActiveMask, CausalForBackwardMask<true>> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(get<0>(problem_shape));
|
||||
flops *= static_cast<double>(get<1>(problem_shape));
|
||||
flops *= (3 * static_cast<double>(get<2>(problem_shape)) + 2 * static_cast<double>(get<3>(problem_shape)));
|
||||
flops *= static_cast<double>(get<4,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<4,0,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<4,0,1>(problem_shape));
|
||||
flops *= static_cast<double>(get<4,1>(problem_shape));
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
@ -1001,7 +984,7 @@ int main_single(int argc, char const **args) {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " D_VO " << options.d_vo << " ";
|
||||
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
|
||||
@ -80,6 +80,7 @@ struct Options {
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
bool is_fused_reduction = false;
|
||||
|
||||
int sm_count = 0;
|
||||
|
||||
@ -139,9 +140,12 @@ struct Options {
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
|
||||
if (split_kv == 0) {
|
||||
split_kv = 1;
|
||||
}
|
||||
cmd.get_cmd_line_argument("page", page, defaults.page);
|
||||
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
|
||||
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
|
||||
is_var_split_kv = cmd.check_cmd_line_flag("var_split_kv");
|
||||
if (page == -1) {
|
||||
is_var_split_kv = false;
|
||||
}
|
||||
@ -149,6 +153,10 @@ struct Options {
|
||||
if (is_var_split_kv == true) {
|
||||
split_kv = max_split_kv;
|
||||
}
|
||||
is_fused_reduction = cmd.check_cmd_line_flag("fuse_reduction");
|
||||
if (split_kv == 1) {
|
||||
is_fused_reduction = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
@ -176,6 +184,8 @@ struct Options {
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --spread=<float> Relative spread away from K for paging\n"
|
||||
<< " --split_kv=<int> Split KV factor\n"
|
||||
<< " --fused_reduction Fuse the reduction operation\n"
|
||||
<< " --var_split_kv Use varying split KV factor\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
@ -514,7 +524,8 @@ struct Runner {
|
||||
stride_LSE},
|
||||
hw_info,
|
||||
options.split_kv,
|
||||
options.is_var_split_kv ? block_split_kv.get() : nullptr
|
||||
options.is_var_split_kv ? block_split_kv.get() : nullptr,
|
||||
options.is_fused_reduction
|
||||
};
|
||||
if (options.split_kv < 0 && !options.is_var_split_kv) {
|
||||
Operation::set_split_kv(arguments);
|
||||
@ -724,13 +735,17 @@ void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
if (!options.is_fused_reduction || options.split_kv == 1) {
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
}
|
||||
#elif FP16
|
||||
name += " fp16";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
if (!options.is_fused_reduction || options.split_kv == 1) {
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -90,6 +90,7 @@ struct Options {
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
bool causal_q_begin = true;
|
||||
bool residual = false;
|
||||
bool varlen = false;
|
||||
bool persistent = false;
|
||||
@ -231,6 +232,8 @@ struct Options {
|
||||
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
std::string causal_type;
|
||||
cmd.get_cmd_line_argument<std::string>("causal-type", causal_type, "");
|
||||
if (mask == "no" || mask == "") {
|
||||
causal = residual = false;
|
||||
if (varlen) {
|
||||
@ -240,6 +243,11 @@ struct Options {
|
||||
else if (mask == "causal") {
|
||||
residual = false;
|
||||
causal = true;
|
||||
if(causal_type == "qend") {
|
||||
causal_q_begin = false;
|
||||
} else {
|
||||
causal_q_begin = true;
|
||||
}
|
||||
}
|
||||
else if (mask == "residual") {
|
||||
residual = true;
|
||||
@ -279,6 +287,7 @@ struct Options {
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||
<< " --causal-type=<qbegin|qend> Causal mask type\n"
|
||||
<< " --persistent Enables persistent scheduler\n"
|
||||
<< " --varlen Enables variable sequence length\n"
|
||||
<< " B*Q and B*K become the total sequence length\n"
|
||||
@ -581,8 +590,8 @@ struct MlaFwdRunner {
|
||||
|
||||
ProblemShapeType problem_size_for_launch;
|
||||
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
|
||||
get<2>(problem_size_for_launch) = get<2>(problem_size);
|
||||
get<3>(problem_size_for_launch) = get<3>(problem_size);
|
||||
|
||||
@ -642,9 +651,9 @@ struct MlaFwdRunner {
|
||||
}
|
||||
|
||||
auto buffer_init_fn = [&](auto& buffer) {
|
||||
buffer.block_Q.reset(size(shape_Q), kIsVarlen ? D_latent_rope*SQ*H : 0);
|
||||
buffer.block_K.reset(size(shape_K), kIsVarlen ? D_latent_rope*SK*H_K : 0);
|
||||
buffer.block_V.reset(size(shape_V), kIsVarlen ? D*SK*H_K : 0);
|
||||
buffer.block_Q.reset(size(shape_Q));
|
||||
buffer.block_K.reset(size(shape_K));
|
||||
buffer.block_V.reset(size(shape_V));
|
||||
buffer.block_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_LSE.reset(size(shape_LSE));
|
||||
buffer.block_ref_O.reset(size(shape_O), kIsVarlen ? D*SQ*H : 0);
|
||||
@ -840,7 +849,8 @@ struct MlaFwdRunner {
|
||||
flops *= static_cast<double>(size<3,1>(problem_shape));
|
||||
}
|
||||
|
||||
flops *= 2.0 * (std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
|
||||
flops *= 2.0 * (std::is_same_v<ActiveMask, CausalMask<false>> ||
|
||||
std::is_same_v<ActiveMask, CausalMask<true>> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(size<3,0>(problem_shape));
|
||||
|
||||
double flops0 = flops * static_cast<double>(size<2, 0>(problem_shape) + size<2, 1>(problem_shape));
|
||||
@ -1013,7 +1023,11 @@ int main_single(int argc, char const **args) {
|
||||
|
||||
auto with_mask = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask<false>{});
|
||||
if(options.causal_q_begin) {
|
||||
fn(CausalMask{});
|
||||
} else {
|
||||
fn(CausalMask<false>{});
|
||||
}
|
||||
}
|
||||
else if (options.residual) {
|
||||
fn(ResidualMask{});
|
||||
|
||||
@ -59,6 +59,16 @@ set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2
|
||||
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
|
||||
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
|
||||
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
set(TEST_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
|
||||
set(TEST_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
|
||||
|
||||
|
||||
|
||||
set(TEST_MLA_FWD_VARLEN_00 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_01 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
@ -75,6 +85,15 @@ set(TEST_MLA_FWD_VARLEN_11 --verify --varlen --mask=causal,residual --dl=128 --d
|
||||
set(TEST_MLA_FWD_VARLEN_12 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=177:766 --varlen-k=257:845)
|
||||
set(TEST_MLA_FWD_VARLEN_13 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=2 --varlen-q=177:0:479 --varlen-k=257:0:766)
|
||||
set(TEST_MLA_FWD_VARLEN_14 --verify --varlen --mask=causal,residual --dl=128 --dr=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
set(TEST_MLA_FWD_VARLEN_15 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_16 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_MLA_FWD_VARLEN_17 --verify --varlen --mask=causal --causal-type=qbegin --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_MLA_FWD_VARLEN_18 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_MLA_FWD_VARLEN_19 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=257)
|
||||
set(TEST_MLA_FWD_VARLEN_20 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=17 --varlen-k=25)
|
||||
set(TEST_MLA_FWD_VARLEN_21 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1013 --varlen-k=1024)
|
||||
set(TEST_MLA_FWD_VARLEN_22 --verify --varlen --mask=causal --causal-type=qend --d=128 --h=4 --h_k=4 --varlen-q=1024 --varlen-k=1035)
|
||||
|
||||
|
||||
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
|
||||
@ -87,6 +106,9 @@ set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
|
||||
set(TEST_BWD_MLA_BASIC --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=no)
|
||||
set(TEST_BWD_MLA_VARLEN --b=1 --h=4 --q=512 --k=512 --d=192 --d_vo=128 --verify --mask=residual --varlen)
|
||||
|
||||
set(TEST_MLA_SEP_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --verify)
|
||||
set(TEST_MLA_FUSE_REDUCTION --b=1 --k=4096 --split_kv=8 --page=128 --fuse_reduction --verify)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
|
||||
foreach(PREC fp8 fp16)
|
||||
@ -116,6 +138,14 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
TEST_VARLEN_12
|
||||
TEST_VARLEN_13
|
||||
TEST_VARLEN_14
|
||||
TEST_VARLEN_15
|
||||
TEST_VARLEN_16
|
||||
TEST_VARLEN_17
|
||||
TEST_VARLEN_18
|
||||
TEST_VARLEN_19
|
||||
TEST_VARLEN_20
|
||||
TEST_VARLEN_21
|
||||
TEST_VARLEN_22
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -139,6 +169,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
TEST_MLA_SEP_REDUCTION
|
||||
TEST_MLA_FUSE_REDUCTION
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -149,6 +181,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
TEST_MLA_SEP_REDUCTION
|
||||
TEST_MLA_FUSE_REDUCTION
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
|
||||
@ -207,6 +241,14 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
TEST_MLA_FWD_VARLEN_12
|
||||
TEST_MLA_FWD_VARLEN_13
|
||||
TEST_MLA_FWD_VARLEN_14
|
||||
TEST_MLA_FWD_VARLEN_15
|
||||
TEST_MLA_FWD_VARLEN_16
|
||||
TEST_MLA_FWD_VARLEN_17
|
||||
TEST_MLA_FWD_VARLEN_18
|
||||
TEST_MLA_FWD_VARLEN_19
|
||||
TEST_MLA_FWD_VARLEN_20
|
||||
TEST_MLA_FWD_VARLEN_21
|
||||
TEST_MLA_FWD_VARLEN_22
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_fwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_fwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
@@ -8,7 +8,7 @@ For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit

Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.

For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` in the `problem_shape`.

The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
The kernel and collective layer are then formulated to be fmha-specific.
@@ -67,6 +67,8 @@ For detailed information on how to invoke them, check out either the tests in `C
to simplify the sample, clarified that `fmha_gen` sample only supports head
dim 128.

* 4.3.0: For variable sequence length, the code requires a batch of valid (but never used) padding memory ahead of the first output batch. No padding is needed for the input tensor, but it requires that the input tensor contain no NaN or Inf values. Note that users should set `total_length` in the `problem_shape`.

# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
|
||||
@@ -203,13 +203,12 @@ struct CausalMask : NoMask {

// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
if constexpr (IsQBegin) {
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
} else {
const int offset_q = get<1>(problem_size) - get<0>(problem_size);
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
|
||||
@ -222,12 +221,12 @@ struct CausalMask : NoMask {
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
if constexpr (IsQBegin) {
|
||||
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
|
||||
} else {
|
||||
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
|
||||
return ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape));
|
||||
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape)));
|
||||
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
|
||||
}
|
||||
}
|
||||
|
||||
@ -277,9 +276,10 @@ struct CausalMask : NoMask {
|
||||
}
|
||||
};
|
||||
|
||||
struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
|
||||
template<bool kIsQBegin = true>
|
||||
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
|
||||
|
||||
using Base = CausalMask<true>;
|
||||
using Base = CausalMask<kIsQBegin>;
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
@ -296,10 +296,15 @@ struct CausalForBackwardMask : CausalMask<true>, ResidualMaskForBackward {
|
||||
// where we only compute the next row and use cache for the rest
|
||||
// - if you'd like this, you only need to add an offset like so:
|
||||
// get<0>(pos) + offset_q < get<1>(pos)
|
||||
int offset_q = 0;
|
||||
if constexpr (!kIsQBegin) {
|
||||
offset_q = get<1>(problem_size) - get<0>(problem_size);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
bool masked = (get<0>(pos) < get<1>(pos)) || !elem_less(pos, problem_size);
|
||||
bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size);
|
||||
if (masked) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
|
||||
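To make the effect of the bottom-right-aligned causal path above concrete, the small standalone program below evaluates the same predicate, masked = (q + offset_q < k) with offset_q = seqlen_kv - seqlen_q; it is an illustration only and not part of the sample.

#include <cstdio>

int main() {
  int seqlen_q = 3, seqlen_kv = 5;
  int offset_q = seqlen_kv - seqlen_q;  // 2: aligns the causal diagonal to the bottom-right
  for (int q = 0; q < seqlen_q; ++q) {
    for (int k = 0; k < seqlen_kv; ++k) {
      bool masked = (q + offset_q) < k;  // same predicate applied to acc_qk above
      std::printf("%c", masked ? '.' : 'x');
    }
    std::printf("\n");
  }
  // prints:
  // xxx..
  // xxxx.
  // xxxxx
  return 0;
}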
@ -534,14 +534,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
|
||||
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
|
||||
auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
|
||||
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
|
||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
|
||||
// Each thread owns a single row
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
|
||||
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
|
||||
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
|
||||
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
|
||||
|
||||
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
|
||||
|
||||
@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
auto dQ = args.dQ;
|
||||
auto dK = args.dK;
|
||||
auto dV = args.dV;
|
||||
auto problem_shape_qk = problem_shape;
|
||||
|
||||
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
|
||||
IntProblemShape problem_shape_qk;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dQ) = get<0>(dQ);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_Q -= max_length_q * get<0>(dQ);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
|
||||
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_kv != nullptr) {
|
||||
int max_length_kv = get<1>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dK) = get<0>(dK);
|
||||
get<2,1>(dV) = get<0>(dV);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_K -= max_length_kv * get<0>(dK);
|
||||
ptr_V -= max_length_kv * get<0>(dV);
|
||||
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr) {
|
||||
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
|
||||
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
|
||||
get<2>(problem_shape_qk) = get<2>(problem_shape);
|
||||
get<3>(problem_shape_qk) = get<3>(problem_shape);
|
||||
}
|
||||
} else {
|
||||
problem_shape_qk = problem_shape;
|
||||
}
|
||||
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
|
||||
@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
|
||||
int q_offs_0 = 0;
|
||||
int q_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
q_offs_0 = max_length_q - get<0>(problem_shape);
|
||||
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
|
||||
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
|
||||
get<2,1>(blk_coord_q) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
|
||||
|
||||
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
|
||||
@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
|
||||
|
||||
int kv_offs_0 = 0;
|
||||
int kv_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
|
||||
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length != nullptr) {
|
||||
int max_length = get<1>(params_problem_shape).max_length;
|
||||
kv_offs_0 = max_length - get<1>(problem_shape);
|
||||
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
|
||||
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
|
||||
get<2,1>(blk_coord_kv) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
|
||||
|
||||
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
|
||||
@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
|
||||
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
|
||||
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
|
||||
|
||||
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
|
||||
|
||||
@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
|
||||
auto dQ = args.dQ;
|
||||
auto dK = args.dK;
|
||||
auto dV = args.dV;
|
||||
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
|
||||
|
||||
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
|
||||
IntProblemShape problem_shape_qk;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dQ) = get<0>(dQ);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_Q -= max_length_q * get<0>(dQ);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
|
||||
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_kv != nullptr) {
|
||||
int max_length_kv = get<1>(problem_shape).max_length;
|
||||
// for variable sequence length, the batch is in units of row_stride
|
||||
get<2,1>(dK) = get<0>(dK);
|
||||
get<2,1>(dV) = get<0>(dV);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_K -= max_length_kv * get<0>(dK);
|
||||
ptr_V -= max_length_kv * get<0>(dV);
|
||||
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr) {
|
||||
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
|
||||
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
|
||||
get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
|
||||
get<3>(problem_shape_qk) = get<3>(problem_shape);
|
||||
}
|
||||
} else {
|
||||
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
|
||||
}
|
||||
|
||||
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
|
||||
@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
|
||||
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
|
||||
|
||||
int q_offs_0 = 0;
|
||||
int q_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
q_offs_0 = max_length_q - get<0>(problem_shape);
|
||||
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
|
||||
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
|
||||
get<2,1>(blk_coord_q) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
|
||||
|
||||
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
|
||||
@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
|
||||
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
|
||||
|
||||
int kv_offs_0 = 0;
|
||||
int kv_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
|
||||
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length != nullptr) {
|
||||
int max_length = get<1>(params_problem_shape).max_length;
|
||||
kv_offs_0 = max_length - get<1>(problem_shape);
|
||||
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
|
||||
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
|
||||
get<2,1>(blk_coord_kv) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
|
||||
|
||||
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
|
||||
@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
|
||||
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
|
||||
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
|
||||
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
|
||||
|
||||
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
|
||||
|
||||
@ -60,6 +60,38 @@ template<
|
||||
class Mask
|
||||
>
|
||||
class Sm100FmhaBwd {
|
||||
private:
|
||||
template <typename T>
|
||||
constexpr static auto to_bwd_shape(T shape) {
|
||||
if constexpr (IsMla) { // remove GQA mode
|
||||
constexpr int R = decltype(rank(shape))::value;
|
||||
auto HB = get<R-1>(shape);
|
||||
auto rest = take<0,R-1>(shape);
|
||||
return append(rest, make_shape(size<0>(HB), get<1>(HB)));
|
||||
}
|
||||
else {
|
||||
return shape;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr static auto to_bwd_stride(T stride) {
|
||||
if constexpr (IsMla) { // remove GQA mode
|
||||
constexpr int R = decltype(rank(stride))::value;
|
||||
auto HB = get<R-1>(stride);
|
||||
auto rest = take<0,R-1>(stride);
|
||||
if constexpr (is_same_v<remove_cv_t<decltype(get<0,0>(HB))>, _0>) {
|
||||
return append(rest, make_stride(get<0,1>(HB), get<1>(HB)));
|
||||
}
|
||||
else {
|
||||
return append(rest, make_stride(get<0,0>(HB), get<1>(HB)));
|
||||
}
|
||||
}
|
||||
else {
|
||||
return stride;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Argument structure: User API
|
||||
struct Arguments {
|
||||
@ -67,26 +99,26 @@ public:
|
||||
ProblemShape problem_shape;
|
||||
|
||||
const Element* ptr_Q;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_Q;
|
||||
const Element* ptr_K;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_K;
|
||||
const Element* ptr_V;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_V;
|
||||
|
||||
const Element* ptr_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_O;
|
||||
const ElementAccumulator* ptr_LSE;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
|
||||
cute::tuple<cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_LSE;
|
||||
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dO;
|
||||
|
||||
Element* ptr_dQ;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int,int>, int>> stride_dQ;
|
||||
Element* ptr_dK;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dK;
|
||||
Element* ptr_dV;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<cute::_0,int>, int>> stride_dV;
|
||||
|
||||
ElementAccumulator softmax_scale;
|
||||
|
||||
@ -106,9 +138,10 @@ public:
|
||||
>
|
||||
>;
|
||||
|
||||
using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{}));
|
||||
using OperationMla = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized<
|
||||
ProblemShape, Element, ElementAccumulator, TileShape, Mask
|
||||
ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask
|
||||
>
|
||||
>;
|
||||
|
||||
@ -134,10 +167,11 @@ private:
|
||||
using namespace cute;
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
auto [H_R, H_K] = H;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
|
||||
auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K));
|
||||
auto log2_e = log2f(expf(1.0f));
|
||||
return typename OperationSumOdO::Arguments {
|
||||
args.problem_shape,
|
||||
@ -154,14 +188,15 @@ private:
|
||||
using namespace cute;
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
auto [H_R, H_K] = H;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||
auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K));
|
||||
return typename OperationConvert::Arguments {
|
||||
args.problem_shape,
|
||||
src, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
nullptr, args.stride_dK,
|
||||
nullptr, args.stride_dV,
|
||||
args.ptr_dQ, args.stride_dQ,
|
||||
nullptr, args.stride_dK,
|
||||
nullptr, args.stride_dV,
|
||||
@ -171,22 +206,22 @@ private:
|
||||
|
||||
static typename Operation::Arguments to_bwd_arguments(
|
||||
Arguments const& args,
|
||||
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
|
||||
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
|
||||
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
|
||||
|
||||
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_sum_OdO = {},
|
||||
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_scaled_lse = {},
|
||||
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> const& stride_dQ = {}) {
|
||||
|
||||
return typename Operation::Arguments{
|
||||
args.problem_shape,
|
||||
{ args.ptr_Q, args.stride_Q,
|
||||
args.ptr_K, args.stride_K,
|
||||
args.ptr_V, args.stride_V,
|
||||
args.ptr_dO, args.stride_dO,
|
||||
scaled_lse, stride_scaled_lse,
|
||||
sum_OdO, stride_sum_OdO,
|
||||
dQ_acc, stride_dQ,
|
||||
to_bwd_shape(args.problem_shape),
|
||||
{ args.ptr_Q, to_bwd_stride(args.stride_Q),
|
||||
args.ptr_K, to_bwd_stride(args.stride_K),
|
||||
args.ptr_V, to_bwd_stride(args.stride_V),
|
||||
args.ptr_dO, to_bwd_stride(args.stride_dO),
|
||||
scaled_lse, to_bwd_stride(stride_scaled_lse),
|
||||
sum_OdO, to_bwd_stride(stride_sum_OdO),
|
||||
dQ_acc, to_bwd_stride(stride_dQ),
|
||||
args.softmax_scale },
|
||||
{ args.ptr_dK, args.stride_dK,
|
||||
args.ptr_dV, args.stride_dV },
|
||||
{ args.ptr_dK, to_bwd_stride(args.stride_dK),
|
||||
args.ptr_dV, to_bwd_stride(args.stride_dV) },
|
||||
args.hw_info
|
||||
};
|
||||
}
|
||||
@ -220,7 +255,7 @@ public:
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
auto [H, B] = product_each(HB);
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
size_t workspace_bytes = 0;
|
||||
@ -240,7 +275,7 @@ public:
|
||||
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
auto [H, B] = product_each(HB);
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
|
||||
@ -269,7 +304,7 @@ public:
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
auto [H, B] = product_each(HB);
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
|
||||
char* workspace_chr = reinterpret_cast<char*>(workspace);
|
||||
|
||||
@ -127,7 +127,11 @@ public:
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
if (args.is_fused_reduction && split_wave_aware > 1) {
|
||||
args.split_kv = std::min(split_wave_aware, static_cast<int>(sm_count/2));
|
||||
} else {
|
||||
args.split_kv = split_wave_aware;
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
@ -273,11 +277,33 @@ public:
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
auto [H, K, D, B] = params.fmha_params.problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
|
||||
auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaMemsetAsync() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int);
|
||||
uint8_t* ws = reinterpret_cast<uint8_t*>(params.fmha_params.epilogue.ptr_lse_exchange_buff);
|
||||
result = cudaMemsetAsync(ws, 0, total_bytes, stream);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaMemsetAsync() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
@ -298,7 +324,7 @@ public:
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
|
||||
@ -46,18 +46,18 @@ struct FmhaKernelBwdConvert {
|
||||
ProblemShape problem_shape;
|
||||
|
||||
const ElementAcc* ptr_src_dQ;
|
||||
tuple<int, _1, tuple<int, int>> stride_src_dQ;
|
||||
tuple<int, _1, tuple<tuple<int, int>, int>> stride_src_dQ;
|
||||
const ElementAcc* ptr_src_dK;
|
||||
tuple<int, _1, tuple<int, int>> stride_src_dK;
|
||||
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dK;
|
||||
const ElementAcc* ptr_src_dV;
|
||||
tuple<int, _1, tuple<int, int>> stride_src_dV;
|
||||
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_src_dV;
|
||||
|
||||
Element* ptr_dest_dQ;
|
||||
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
|
||||
tuple<int, _1, tuple<tuple<int, int>, int>> stride_dest_dQ;
|
||||
Element* ptr_dest_dK;
|
||||
tuple<int, _1, tuple<int, int>> stride_dest_dK;
|
||||
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dK;
|
||||
Element* ptr_dest_dV;
|
||||
tuple<int, _1, tuple<int, int>> stride_dest_dV;
|
||||
tuple<int, _1, tuple<tuple<_0, int>, int>> stride_dest_dV;
|
||||
|
||||
ElementAcc scale = 1.0;
|
||||
};
|
||||
@ -104,8 +104,8 @@ struct FmhaKernelBwdConvert {
|
||||
|
||||
template<class StrideSrc, class StrideDest, class Count>
|
||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) {
|
||||
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
|
||||
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
|
||||
auto ptr_src_bh = ptr_src + get<2,0,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
|
||||
auto ptr_dest_bh = ptr_dest + get<2,0,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
|
||||
|
||||
int seqlen = count;
|
||||
if constexpr (is_variable_length_v<decltype(count)>) {
|
||||
|
||||
@ -46,18 +46,18 @@ struct FmhaKernelBwdSumOdO {
|
||||
ProblemShape problem_shape;
|
||||
|
||||
const Element* ptr_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_O;
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
|
||||
cute::tuple<int, cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_dO;
|
||||
|
||||
ElementAcc* ptr_sum_OdO;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
|
||||
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_sum_OdO;
|
||||
|
||||
const ElementAcc* ptr_lse = nullptr;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
|
||||
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_lse;
|
||||
|
||||
ElementAcc* ptr_scaled_lse = nullptr;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
|
||||
cute::tuple<cute::_1, cute::tuple<cute::tuple<int, int>, int>> stride_scaled_lse;
|
||||
|
||||
ElementAcc sum_odo_scale = 1.0;
|
||||
ElementAcc lse_scale = 1.0;
|
||||
@ -104,11 +104,11 @@ struct FmhaKernelBwdSumOdO {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
|
||||
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
|
||||
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
|
||||
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
|
||||
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
|
||||
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
|
||||
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
|
||||
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
|
||||
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
|
||||
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
|
||||
|
||||
auto problem_q = get<0>(params.problem_shape);
|
||||
int seqlen_q = problem_q;
|
||||
|
||||
@ -119,13 +119,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
|
||||
static constexpr int kStages = 2;
|
||||
|
||||
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
|
||||
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
|
||||
using TensorStrideContiguousK = Stride<int, _1, Stride<Stride<int,int>, int>>;
|
||||
using TensorStrideContiguousMN = Stride<_1, int, Stride<Stride<int,int>, int>>;
|
||||
using TensorStrideContiguousK_GQA = Stride<int, _1, Stride<Stride<_0,int>, int>>;
|
||||
using TensorStrideContiguousMN_GQA = Stride<_1, int, Stride<Stride<_0,int>, int>>;
|
||||
|
||||
// compute S
|
||||
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
Element, TensorStrideContiguousK, Alignment,
|
||||
Element, TensorStrideContiguousK_GQA, Alignment,
|
||||
Element, TensorStrideContiguousK, Alignment,
|
||||
ElementAcc,
|
||||
Shape<TileShapeK, TileShapeQ, TileShapeDQK>,
|
||||
@ -137,7 +139,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
// compute dP
|
||||
using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
Element, TensorStrideContiguousK, Alignment,
|
||||
Element, TensorStrideContiguousK_GQA, Alignment,
|
||||
Element, TensorStrideContiguousK, Alignment,
|
||||
ElementAcc,
|
||||
Shape<TileShapeK, TileShapeQ, TileShapeDVO>,
|
||||
@ -177,7 +179,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
// somewhat arbitrary since we dump to smem, need to agree with the previous one
|
||||
Element, TensorStrideContiguousMN, Alignment,
|
||||
Element, TensorStrideContiguousMN, Alignment,
|
||||
Element, TensorStrideContiguousMN_GQA, Alignment,
|
||||
ElementAcc,
|
||||
Shape<TileShapeQ, TileShapeDQK, TileShapeK>,
|
||||
ClusterShape, cutlass::gemm::collective::StageCount<kStages>,
|
||||
@ -278,15 +280,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
|
||||
|
||||
using TensorStride = TensorStrideContiguousK; // S D (H B)
|
||||
using RowTensorStride = Stride<_1, Stride<int, int>>; // S (H B)
|
||||
using TensorStride_GQA = TensorStrideContiguousK_GQA;
|
||||
using RowTensorStride = Stride<_1, Stride<Stride<int, int>, int>>; // S (H B)
|
||||
|
||||
struct MainloopArguments {
|
||||
const Element* ptr_q;
|
||||
TensorStride stride_q;
|
||||
const Element* ptr_k;
|
||||
TensorStride stride_k;
|
||||
TensorStride_GQA stride_k;
|
||||
const Element* ptr_v;
|
||||
TensorStride stride_v;
|
||||
TensorStride_GQA stride_v;
|
||||
const Element* ptr_do;
|
||||
TensorStride stride_do;
|
||||
|
||||
@ -308,7 +311,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B;
|
||||
|
||||
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
|
||||
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}),
|
||||
make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(make_shape(1,1), 1)), TensorStride{}),
|
||||
SmemLayoutDQ{}(_, _, _0{})
|
||||
));
|
||||
|
||||
@ -322,9 +325,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
struct EpilogueArguments {
|
||||
Element* ptr_dk;
|
||||
TensorStride stride_dk;
|
||||
TensorStride_GQA stride_dk;
|
||||
Element* ptr_dv;
|
||||
TensorStride stride_dv;
|
||||
TensorStride_GQA stride_dv;
|
||||
};
|
||||
|
||||
struct Arguments {
|
||||
@ -346,7 +349,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
static bool can_implement(Arguments const& args) {
|
||||
auto [Q, K, D, D_VO, HB] = args.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) {
|
||||
auto [H_R, H_K] = H;
|
||||
if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H_R <= 0 || H_K <= 0 || B <= 0) {
|
||||
return false;
|
||||
}
|
||||
if (D % Alignment != 0 || D_VO % Alignment != 0) {
|
||||
@ -432,7 +436,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
BlkCoord const& blk_coord,
|
||||
BlkOffset const& blk_offset,
|
||||
ProblemShape_ const& problem_shape,
|
||||
int iter_index,
|
||||
int iter_start,
|
||||
int iter_end,
|
||||
int iter_count,
|
||||
MainloopArguments const& mainloop_args,
|
||||
MainloopParams const& mainloop_params,
|
||||
@ -447,6 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) {
|
||||
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
int iter_index = iter_start;
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
@ -590,6 +596,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
iter_index += 1;
|
||||
|
||||
while (iter_count > 0) {
|
||||
if (iter_index == iter_end) {
|
||||
iter_index = iter_start;
|
||||
get<0,0>(blk_coord_batch) += 1;
|
||||
}
|
||||
|
||||
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
|
||||
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
|
||||
|
||||
@ -660,7 +671,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
CUTLASS_DEVICE void mma(
|
||||
BlkCoord const& blk_coord,
|
||||
ProblemShape_ const& problem_shape,
|
||||
int iter_index,
|
||||
int iter_start,
|
||||
int iter_end,
|
||||
int iter_count,
|
||||
MainloopArguments const& mainloop_args,
|
||||
TensorStorage& shared_tensors,
|
||||
@ -1119,7 +1131,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
BlkCoord const& blk_coord,
|
||||
BlkOffset const& blk_offset,
|
||||
ProblemShape_ const& problem_shape,
|
||||
int iter_index,
|
||||
int iter_start,
|
||||
int iter_end,
|
||||
int iter_count,
|
||||
MainloopArguments const& mainloop_args,
|
||||
EpilogueArguments const& epilogue_args,
|
||||
@ -1141,6 +1154,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
int iter_index = iter_start;
|
||||
|
||||
// in tmem, S & P overlap
|
||||
// and dP and dQ overlap
|
||||
@ -1201,6 +1215,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
Tensor tTR_cST_p = thread_t2r.partition_D(cST);
|
||||
Tensor tTR_cST = split_wg(tTR_cST_p);
|
||||
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
|
||||
// Tensor tTR_tST_p = thread_t2r.partition_S(tSTtST);
|
||||
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
|
||||
|
||||
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
|
||||
@ -1224,8 +1239,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
auto tRT_cST_p = thread_r2t.partition_S(tDVcST);
|
||||
auto tRT_cST = split_wg(tRT_cST_p);
|
||||
|
||||
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape);
|
||||
int last_iter = iter_count - 1 + iter_index;
|
||||
bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} > get<1>(problem_shape);
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (iter_count > 0) {
|
||||
@ -1245,13 +1259,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
};
|
||||
|
||||
bool leading_causal_masking = false;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|
||||
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
|
||||
leading_causal_masking = warp_uniform(iter_index == iter_start);
|
||||
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
int offset = get<1>(problem_shape) - get<0>(problem_shape);
|
||||
int kv_left = get<1>(blk_coord) * TileShapeK{};
|
||||
int kv_right = kv_left + TileShapeK{} - 1;
|
||||
int q_left = iter_index * TileShapeQ{} + offset;
|
||||
int q_right = q_left + TileShapeQ{} - 1;
|
||||
|
||||
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
|
||||
}
|
||||
bool trailing_residual_masking = false;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
|
||||
trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k);
|
||||
trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k);
|
||||
}
|
||||
|
||||
dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) {
|
||||
@ -1372,6 +1393,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
iter_count -= 1;
|
||||
iter_index += 1;
|
||||
if (iter_index == iter_end) {
|
||||
iter_index = iter_start;
|
||||
}
|
||||
}
|
||||
|
||||
epilogue(
|
||||
@ -1384,7 +1408,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
CUTLASS_DEVICE void reduce(
|
||||
BlkCoord const& blk_coord,
|
||||
ProblemShape_ const& problem_shape,
|
||||
int iter_index,
|
||||
int iter_start,
|
||||
int iter_end,
|
||||
int iter_count,
|
||||
MainloopArguments const& mainloop_args,
|
||||
MainloopParams const& mainloop_params,
|
||||
@ -1397,6 +1422,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
using X = Underscore;
|
||||
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
int iter_index = iter_start;
|
||||
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord;
|
||||
|
||||
@ -1408,7 +1434,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
|
||||
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
|
||||
(_, _, _, _0{}, blk_coord_batch);
|
||||
(_, _, _, _0{}, _);
|
||||
|
||||
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
|
||||
|
||||
@ -1419,7 +1445,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
|
||||
|
||||
Tensor tTR_cDQ = thread_t2r.partition_D(cDQ);
|
||||
Tensor tTR_gDQ = thread_t2r.partition_D(gDQ);
|
||||
Tensor tTR_sDQ = thread_t2r.partition_D(sDQ);
|
||||
Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ);
|
||||
|
||||
@ -1465,7 +1490,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
).arrive_and_wait();
|
||||
if (lane_predicate) {
|
||||
// launch tma store
|
||||
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index));
|
||||
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch));
|
||||
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
|
||||
}
|
||||
|
||||
@ -1474,11 +1499,18 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
iter_count -= 1;
|
||||
iter_index += 1;
|
||||
if (iter_index == iter_end) {
|
||||
iter_index = iter_start;
|
||||
get<0,0>(blk_coord_batch) += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
|
||||
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
|
||||
#else
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
auto role = warp_idx_to_role(warp_idx);
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
@ -1676,21 +1708,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
pipeline_init_wait(size(ClusterShape{}));
|
||||
|
||||
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
||||
auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z));
|
||||
auto [problem_shape, blk_offset] = apply_variable_length_offset(
|
||||
params.problem_shape,
|
||||
blk_coord
|
||||
);
|
||||
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
|
||||
int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{});
|
||||
int iter_start = 0;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
|
||||
std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
|
||||
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
|
||||
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
int offset = get<1>(problem_shape) - get<0>(problem_shape);
|
||||
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
|
||||
}
|
||||
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
|
||||
return;
|
||||
}
|
||||
iter_count -= iter_start;
|
||||
int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape);
|
||||
|
||||
if (iter_count <= 0) {
|
||||
epilogue_clear(
|
||||
@ -1711,6 +1745,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
blk_offset,
|
||||
problem_shape,
|
||||
iter_start,
|
||||
iter_end,
|
||||
iter_count,
|
||||
params.mainloop,
|
||||
params.mainloop_params,
|
||||
@ -1732,6 +1767,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
iter_start,
|
||||
iter_end,
|
||||
iter_count,
|
||||
params.mainloop,
|
||||
shared_storage.tensors,
|
||||
@ -1754,6 +1790,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
blk_offset,
|
||||
problem_shape,
|
||||
iter_start,
|
||||
iter_end,
|
||||
iter_count,
|
||||
params.mainloop,
|
||||
params.epilogue,
|
||||
@ -1785,6 +1822,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
iter_start,
|
||||
iter_end,
|
||||
iter_count,
|
||||
params.mainloop,
|
||||
params.mainloop_params,
|
||||
@ -1801,6 +1839,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
/* no-op */
|
||||
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
@ -1811,7 +1850,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
auto [Q, K, D, D_VO, HB] = params.problem_shape;
|
||||
auto [H, B] = HB;
|
||||
dim3 grid(ceil_div(K, TileShapeK{}), H, B);
|
||||
auto [H_R, H_K] = H;
|
||||
dim3 grid(ceil_div(K, TileShapeK{}), H_K, B);
|
||||
return grid;
|
||||
}
|
||||
};
|
||||
|
||||
@ -1230,9 +1230,16 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
|
||||
};
|
||||
|
||||
bool leading_causal_masking = false;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|
||||
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
|
||||
leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord));
|
||||
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
int offset = get<1>(problem_shape) - get<0>(problem_shape);
|
||||
int kv_left = get<1>(blk_coord) * TileShapeK{};
|
||||
int kv_right = kv_left + TileShapeK{} - 1;
|
||||
int q_left = iter_index * TileShapeQ{} + offset;
|
||||
int q_right = q_left + TileShapeQ{} - 1;
|
||||
|
||||
leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left)));
|
||||
}
|
||||
bool trailing_residual_masking = false;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::ResidualMaskForBackward, Mask>) {
|
||||
@ -1473,6 +1480,9 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
|
||||
|
||||
|
||||
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
|
||||
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
|
||||
#else
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
auto role = warp_idx_to_role(warp_idx);
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
@ -1677,9 +1687,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
|
||||
);
|
||||
int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{});
|
||||
int iter_start = 0;
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>
|
||||
|| std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
|
||||
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
|
||||
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
|
||||
int offset = get<1>(problem_shape) - get<0>(problem_shape);
|
||||
iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
|
||||
}
|
||||
if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) {
|
||||
return;
|
||||
@ -1795,6 +1807,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
|
||||
/* no-op */
|
||||
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
|
||||
@ -251,6 +251,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
@ -465,6 +468,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
else if (role == WarpRole::Correction) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
|
||||
|
||||
bool has_valid = false;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
@ -476,6 +481,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
has_valid = true;
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
mainloop.correction_empty(
|
||||
blk_coord,
|
||||
@ -505,16 +512,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
if constexpr (NumWarpsEpilogue == 0) {
|
||||
static_assert(NumWarpsCorrection == 1);
|
||||
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
if (has_valid) {
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::MMA) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||
__syncwarp();
|
||||
bool allocated = false;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
@ -527,6 +535,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!allocated) {
|
||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||
__syncwarp();
|
||||
allocated = true;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
continue;
|
||||
}
|
||||
@ -580,6 +594,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
else if (role == WarpRole::Epilogue) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
bool has_valid = false;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
@ -591,6 +607,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
has_valid = true;
|
||||
|
||||
epilogue.store(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.epilogue, params.problem_shape,
|
||||
@ -602,8 +620,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
|
||||
static_assert(NumWarpsEpilogue <= 1);
|
||||
if constexpr (NumWarpsEpilogue == 1) {
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
if (has_valid) {
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -612,6 +632,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
|
||||
/* no-op, donate regs and exit */
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
@ -247,6 +247,9 @@ struct Sm100FmhaGenKernelWarpspecialized {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
@ -365,7 +368,7 @@ struct Sm100FmhaGenKernelWarpspecialized {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
|
||||
pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp);
|
||||
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
|
||||
shared_storage.pipelines.corr_epi,
|
||||
pipeline_corr_epi_params,
|
||||
@ -569,6 +572,7 @@ struct Sm100FmhaGenKernelWarpspecialized {
|
||||
|
||||
/* no-op, donate regs and exit */
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
@ -101,8 +101,17 @@ struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
if (local_split_kv == 1) return;
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
@ -113,12 +122,6 @@ struct Sm100FmhaMlaReductionKernel {
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
@ -130,17 +133,18 @@ struct Sm100FmhaMlaReductionKernel {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
lse_max = fmax(local_lse[i], lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
lse_max = fmax(__shfl_xor_sync(0xffffffff, lse_max, offset), lse_max);
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
|
||||
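For reference, the warp-wide maximum used for lse_max above can be written as the standalone helper below: a butterfly reduction over all 32 lanes with __shfl_xor_sync followed by a broadcast from lane 0. This helper is an illustration and is not part of the sample.

__device__ float warp_reduce_max(float v) {
  #pragma unroll
  for (int offset = 16; offset >= 1; offset /= 2) {
    v = fmaxf(__shfl_xor_sync(0xffffffff, v, offset), v);  // exchange with lane ^ offset, keep the max
  }
  // after the butterfly every lane already holds the maximum; the kernel above still
  // broadcasts from lane 0 so that later code reads a single, well-defined source lane
  return __shfl_sync(0xffffffff, v, 0);
}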
@@ -36,6 +36,7 @@

#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/barrier.h"

#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
@@ -44,6 +45,7 @@

#include "gather_tensor.hpp" // from examples/common
#include "common/pow_2.hpp"
#include "sm100_mla_tile_scheduler.hpp"

namespace cutlass::fmha::kernel {

@@ -87,8 +89,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using TileShapeR = tuple_element_t<1, TileShapeD>;
static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim");

using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
using TensorStride = Stride<int64_t, _1, int64_t>;
using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
using TensorStride = Stride<int64_t, _1, int64_t>;
using TmemAllocator = cute::conditional_t<kIs2Sm, cute::TMEM::Allocator2Sm, cute::TMEM::Allocator1Sm>;

static_assert(TileShapeH{} == 128);
@@ -181,10 +183,13 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB;
using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int<IterationsPV_K>{}, _2{})));
using SmemLayoutOut = decltype(take<0,2>(typename CollectiveMmaQK::CtaShape_MNK{}));
using TileShapeAcc = decltype(take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}));

static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v<Element>);

// pre-condition for overlapped smem staging
static_assert(kBytesLoadKC == kBytesLoadVC);
static_assert(StagesQK == StagesPV);
@@ -226,7 +231,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKC>> smem_kc;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutVC>> smem_vc;
};
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(2048) cute::array<ElementOut, size(TileShapeAcc{})> smem_acc;
};
};

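The hunk above replaces the standalone `smem_p` buffer with a union, so the P operand staging and the output tile staged for the TMA reduce-add reuse the same shared-memory bytes. A minimal sketch of the aliasing pattern, with a placeholder element type and tile sizes (assumptions, not the kernel's actual values):

```cpp
#include <cstdint>

// Placeholder type and sizes; only the aliasing pattern matters. The two union members
// are never live at the same time, so shared memory is paid for only once.
struct TensorStorageSketch {
  alignas(2048) uint16_t smem_kc[128 * 128];
  union {
    alignas(2048) uint16_t smem_p[128 * 32];    // P operand tiles for the PV MMA
    alignas(2048) uint16_t smem_acc[128 * 128]; // output tile staged for TMA reduce-add
  };
};
```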
struct SharedStorage {
@@ -280,6 +288,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
KernelHardwareInfo hw_info;
int split_kv = -1;
int* ptr_split_kv = nullptr;
bool is_fused_reduction = false;
};

using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A;
@@ -288,6 +297,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B;
using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B;

using GmemLayout = decltype(make_layout(Shape<int,int,int>{}, Stride<int64_t, _1, int64_t>{}));
using SmemLayout = decltype(make_layout(TileShapeAcc{}, LayoutRight{}));

using TmaReduceSum = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{},
make_tensor(recast_ptr<ElementOut>(nullptr), GmemLayout{}), SmemLayout{}));

struct MainloopParams {
TmaLoadQLatent tma_load_q_latent;
TmaLoadQRope tma_load_q_rope;
@@ -306,6 +321,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
Stride<_1, int> stride_lse;
Stride<_1, int> stride_lse_acc;
ElementAcc output_scale = 1.0f;
ElementLSE* ptr_lse_exchange_buff = nullptr;
int* ptr_lse_max_exchange_buff = nullptr;
int* ptr_lock = nullptr; // semaphore
TmaReduceSum tma_reduce_sum;
};

struct Params {
@@ -316,6 +335,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
typename TileScheduler::Params tile_scheduler;
int split_kv = -1;
int* ptr_split_kv = nullptr;
bool is_fused_reduction = false;
};

static Params to_underlying_arguments(Arguments const& args, void* workspace) {
@@ -380,11 +400,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {

epilogue_params.ptr_o = args.epilogue.ptr_o;
epilogue_params.stride_o = args.epilogue.stride_o;
epilogue_params.ptr_lse = args.epilogue.ptr_lse;
epilogue_params.ptr_lse = args.epilogue.ptr_lse;
epilogue_params.stride_lse = args.epilogue.stride_lse;
epilogue_params.output_scale = args.epilogue.output_scale;
epilogue_params.tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(recast_ptr<ElementOut>(args.epilogue.ptr_o), make_layout(make_shape(H, L, B), args.epilogue.stride_o)), SmemLayout{});

if (args.split_kv > 1) {
if (!args.is_fused_reduction && args.split_kv > 1) {
ElementAcc* ptr_o_acc = reinterpret_cast<ElementAcc*>(workspace);
ElementLSE* ptr_lse_acc = reinterpret_cast<ElementLSE*>(ptr_o_acc + H * L * args.split_kv * B);
epilogue_params.ptr_o_acc = ptr_o_acc;
@@ -392,10 +413,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {

epilogue_params.stride_o_acc = make_tuple(static_cast<int64_t>(0 + L) * args.split_kv, _1{}, static_cast<int64_t>(0 + H * L) * args.split_kv);
epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv);
} else if (args.is_fused_reduction && args.split_kv > 1) {
ElementLSE* ptr_lse_exchange_buff = reinterpret_cast<ElementLSE*>(workspace);
epilogue_params.ptr_lse_exchange_buff = ptr_lse_exchange_buff;
int* ptr_lse_max_exchange_buff = reinterpret_cast<int*>(ptr_lse_exchange_buff + H * B);
epilogue_params.ptr_lse_max_exchange_buff = ptr_lse_max_exchange_buff;
int* ptr_lock = ptr_lse_max_exchange_buff + H * B;
epilogue_params.ptr_lock = ptr_lock;
}

return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params,
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv};
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv),
args.split_kv, args.ptr_split_kv, args.is_fused_reduction};
}

static size_t get_workspace_size(Arguments const& args) {
@@ -403,10 +432,29 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto split_kv = args.split_kv;
return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
size_t workspace_size {0};
if (args.is_fused_reduction && args.split_kv > 1) {
// one exchange buffer for LSE max and another buffer for total LSE
// two locks per batch, first lock is for CTA0 / H=0..63 and the second is for CTA1 / H=64..127
workspace_size = H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int);
} else if (!args.is_fused_reduction && args.split_kv > 1) {
workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
}
return workspace_size;
}
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
Arguments const& args, void* ws, cudaStream_t stream) {
auto workspace_size = get_workspace_size(args);
if (args.is_fused_reduction && args.split_kv > 1) {
auto result = cudaMemsetAsync(ws, 0, workspace_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaMemsetAsync() returned error: "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}

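A host-side sketch of the two workspace layouts selected above; the element sizes are passed in explicitly here because `ElementAcc` and `ElementLSE` are template parameters in the kernel (this is a sketch of the sizing rule, not the kernel's API):

```cpp
#include <cstddef>

// Fused reduction needs an int LSE-max buffer and an LSE-sum buffer per (head, batch)
// plus two locks per batch; the unfused path instead stores a full per-split O
// accumulator and per-split LSE.
size_t mla_workspace_bytes(bool fused, int split_kv, int H, int B, int D_latent,
                           size_t sizeof_acc, size_t sizeof_lse) {
  if (split_kv <= 1) return 0;
  if (fused) {
    return size_t(H) * B * (sizeof(int) + sizeof_lse) + 2 * size_t(B) * sizeof(int);
  }
  return (sizeof_acc * D_latent + sizeof_lse) * size_t(H) * split_kv * B;
}
```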
@@ -448,11 +496,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
std::cerr << __FILE__ << "(" << __LINE__ << "): split-k off\n";
return false;
}
if (args.is_fused_reduction && args.split_kv > 1) {
if (2 * args.split_kv > args.hw_info.sm_count ||
std::is_same_v<TileScheduler, Sm100MlaIndividualTileScheduler>) {
return false;
}
}
return true;
}

CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED))
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
#else

TileScheduler tile_scheduler(params.tile_scheduler);

@@ -746,7 +803,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_consumer_state,
pipeline_p_mma, pipeline_p_mma_producer_state,
pipeline_mma_o, pipeline_mma_o_consumer_state,
local_split_kv
local_split_kv,
params.is_fused_reduction
);
}

@@ -759,6 +817,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
#endif
}

template<class BlkCoord>
@@ -1777,7 +1836,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {

auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
if (epilogue_args.ptr_o_acc != nullptr) {

if (split_kv > 1) {
using ElementOutAcc = ElementAcc;
constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v<ElementOutAcc>;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc);
@@ -1806,16 +1866,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {

copy(tTR_rO_src, tR2G_rO_dst);

// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
if (get<1>(cta_coord) == 0) {
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;

// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
}
}
}
else {
@@ -1845,24 +1909,165 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {

copy(tTR_rO_src, tR2G_rO_dst);

if (get<1>(cta_coord) == 0) {
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;

if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});

// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});

// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
// for 2x2 dp, this must be conditional and the index is wrong
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
}
}
}
}

template<class BlkCoord>
CUTLASS_DEVICE ElementLSE epilogue_lse_reduction(
ElementAcc& row_max,
ElementAcc& row_sum,
BlkCoord const& cta_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
int const& local_split_kv) {

auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});

constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;

auto wait = [](int* lock, int count) {
__threadfence();
if (threadIdx.x == 0) {
atomicAdd(lock, 1);
while (atomicCAS(lock, count, count) != count) {};
}
__threadfence();
Sync::sync();
};

const ElementLSE lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
Tensor mLSE_max_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE_max_buff = local_tile(mLSE_max_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});

int* local_lock = epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord);

if (! kIs2Sm || (threadIdx.x < 64)) {
atomicMax(&(gLSE_max_buff(threadIdx.x)), __float2int_rn(lse));
}
wait(local_lock, local_split_kv);

auto global_lse_max = static_cast<ElementLSE>(gLSE_max_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x));

Tensor mLSE_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE_buff = local_tile(mLSE_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});

if (! kIs2Sm || (threadIdx.x < 64)) {
atomicAdd(&(gLSE_buff(threadIdx.x)), expf(lse - global_lse_max));
}
wait(local_lock, 2*local_split_kv);

const auto sum_lse = gLSE_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x);
const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementLSE>::infinity() :
cutlass::fast_log(sum_lse) + global_lse_max;
const auto lse_scale = expf(lse - global_lse);

if (epilogue_args.ptr_lse != nullptr) {
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});

// write out the global LSE
if (! kIs2Sm || (threadIdx.x < 64)) {
gLSE(threadIdx.x) = global_lse;
}
}
return lse_scale;
}

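For reference, the arithmetic that the max and sum exchange buffers above implement is the standard numerically safe log-sum-exp combination across splits. A minimal host-side sketch of that math (not the device code):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// global_lse = log(sum_i exp(lse_i - m)) + m with m = max_i lse_i; each split then
// rescales its partial output by exp(lse_i - global_lse) before the reduce-add.
float combine_lse(const std::vector<float>& lse_per_split) {
  float m = -INFINITY;
  for (float v : lse_per_split) m = std::max(m, v);
  float s = 0.f;
  for (float v : lse_per_split) s += std::exp(v - m);
  return std::log(s) + m;
}
```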
template<class BlkCoord>
CUTLASS_DEVICE void epilogue_reduction(
ElementAcc& row_max,
ElementAcc& row_sum,
BlkCoord const& blk_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
TensorStorage& shared_tensors,
int const& local_split_kv,
ElementLSE const& lse_scale) {

constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp;
using Sync = cutlass::detail::NamedBarrierSync<kNumThreads, kNamedBarrierExchange>;

auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;

auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};

TiledMmaPV tiled_mma_pv;
Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{})));

CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{});
CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{});

using EpilogueLinearCombination = cutlass::epilogue::thread::LinearCombination<ElementOut, 1, ElementAcc, ElementAcc, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
EpilogueLinearCombination epilogue_op({epilogue_args.output_scale / row_sum * lse_scale});

CUTLASS_PRAGMA_UNROLL
for(int k = 0; k < IterationsPV_N; ++k) {
auto cta_coord = replace<1>(blk_coord, k);

uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + k * uint32_t(TmemAllocation::kSizeAccO);
tOtO.data() = tmem_o;

Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{});

Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
Tensor gO = local_tile(mO, TileShapeAcc{}, take<0,3>(cta_coord));

auto tiled_t2r = make_tmem_copy(load_op, tAcc);
auto thread_idx = threadIdx.x % size(tiled_t2r);

auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_gO = thread_t2r.partition_D(gO);
Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);

copy(tiled_t2r, tTR_tAcc, tTR_rAcc);

Tensor sO = make_tensor(make_smem_ptr(reinterpret_cast<ElementOut*>(shared_tensors.smem_acc.begin())), SmemLayout{});
Tensor tTR_sO = thread_t2r.partition_D(sO);

Sync::sync();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i++) {
tTR_sO(i) = epilogue_op(tTR_rAcc(i));
}
tma_store_fence();
Sync::sync();

auto tma_reduce_sum_per_cta = epilogue_args.tma_reduce_sum.get_slice(_0{});
auto gmem_tensor_coord = epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO));
auto gmem_tensor_coord_per_cta = local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0,3>(cta_coord));
if (threadIdx.x % kNumThreads == 0) {
copy(epilogue_args.tma_reduce_sum,
tma_reduce_sum_per_cta.partition_S(sO),
tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta));
tma_store_arrive();
}
tma_store_wait<0>();
}
}

template<class CtaCoord>
CUTLASS_DEVICE void compute(
@@ -1877,7 +2082,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
typename PipelineP::PipelineState& pipeline_p_mma_producer_state,
PipelineO& pipeline_mma_o,
typename PipelineO::PipelineState& pipeline_mma_o_consumer_state,
int const& split_kv) {
int const& split_kv,
bool const& is_fused_reduction) {

auto [H, K, D, B] = problem_shape;

@@ -1987,17 +2193,38 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {

cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();

// epilogue
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IterationsPV_N; j++) {
epilogue(
row_max, row_sum,
replace<1>(cta_coord, j), problem_shape,
mainloop_args, epilogue_args, shared_tensors,
uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv
const int actual_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
if (!is_fused_reduction || actual_split_kv == 1) {
// epilogue
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IterationsPV_N; j++) {
epilogue(
row_max, row_sum,
replace<1>(cta_coord, j), problem_shape,
mainloop_args, epilogue_args, shared_tensors,
uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO),
actual_split_kv
);
}
} else {
const ElementLSE lse_scale =
epilogue_lse_reduction(
row_max, row_sum,
cta_coord,
problem_shape,
mainloop_args, epilogue_args,
actual_split_kv
);

epilogue_reduction(row_max, row_sum,
cta_coord,
problem_shape,
mainloop_args, epilogue_args,
shared_tensors,
actual_split_kv,
lse_scale
);
}

cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
++pipeline_mma_o_consumer_state;

@@ -142,8 +142,8 @@ struct Sm100MlaPersistentTileScheduler {
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}

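The hunk above reorders the coordinate decode so the split-KV index is peeled off before the batch index. As a hedged sketch only (assuming a FastDivmod-style quotient/remainder decode; the divisor names here are placeholders), the flat block index is unpacked like this:

```cpp
// Illustrative decode order: m_block varies fastest, then the split-KV index,
// then the batch index.
struct BlockCoord { int m_block, n_split_kv, bidb; };

BlockCoord decode_block(int block_idx, int num_m_blocks, int num_splits) {
  BlockCoord c{};
  c.m_block    = block_idx % num_m_blocks;  block_idx /= num_m_blocks;
  c.n_split_kv = block_idx % num_splits;    block_idx /= num_splits;
  c.bidb       = block_idx;
  return c;
}
```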
@@ -56,12 +56,12 @@ void __global__ fmha_bwd_reference_dQ_kernel(
using namespace cutlass::fmha::collective;

using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
using ElementAcc = typename TensorLSE::value_type;

extern __shared__ char mS_mem[];
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
Element* mS = reinterpret_cast<Element*>(mS_mem);

ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));

for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
auto [problem_shape, offset] = apply_variable_length_offset(
@@ -79,9 +79,9 @@ void __global__ fmha_bwd_reference_dQ_kernel(
auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in);
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
ElementAcc acc_qk = 0;
ElementAcc acc_dov = 0;
ElementAcc acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
@@ -94,20 +94,22 @@ void __global__ fmha_bwd_reference_dQ_kernel(
}

auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);

mS[idx_K] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
mS[idx_K] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
} // for idx_K

__syncthreads();

for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
ElementAcc acc = 0;
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
acc += mS[idx_K] * ElementAccumulator(mK(idx_K, idx_D, idx_L));
ElementAcc rK = mK(idx_K, idx_D, idx_L);
ElementAcc rDS = mS[idx_K];
acc += rDS * rK;
}
mDQ(idx_Q, idx_D, idx_L) = static_cast<typename TensorDQ::value_type>(acc);
} // for idx_D
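The reference kernel above computes, per query row, dS = P * scale * (dO·V − dO·O) and then dQ = dS·K. A plain single-threaded sketch of that math (row-major buffers, and a single head dimension D used for both K and V purely for brevity; these layout choices are assumptions, not the kernel's):

```cpp
#include <cmath>

void dq_reference_sketch(int Q, int K, int D, float scale,
                         const float* q, const float* k, const float* v,
                         const float* o, const float* d_o,
                         const float* lse, float* dq) {
  for (int iq = 0; iq < Q; ++iq) {
    for (int d = 0; d < D; ++d) dq[iq * D + d] = 0.f;
    for (int ik = 0; ik < K; ++ik) {
      float qk = 0.f, dov = 0.f, doo = 0.f;
      for (int d = 0; d < D; ++d) {
        qk  += q[iq * D + d] * k[ik * D + d];    // Q.K
        dov += d_o[iq * D + d] * v[ik * D + d];  // dO.V
        doo += d_o[iq * D + d] * o[iq * D + d];  // dO.O (row-wise delta term)
      }
      float p  = std::exp(scale * qk - lse[iq]); // softmax probability from the LSE
      float ds = p * scale * (dov - doo);
      for (int d = 0; d < D; ++d) dq[iq * D + d] += ds * k[ik * D + d];
    }
  }
}
```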
@ -135,62 +137,83 @@ void __global__ fmha_bwd_reference_dK_kernel(
|
||||
using namespace cutlass::fmha::collective;
|
||||
|
||||
using Element = typename TensorO::value_type;
|
||||
using ElementAccumulator = typename TensorLSE::value_type;
|
||||
using ElementAcc = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ char mS_mem[];
|
||||
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
|
||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||
|
||||
ElementAccumulator softmax_scale = 1.0 / sqrt(ElementAccumulator(size<2>(problem_shape_in)));
|
||||
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
|
||||
auto [H, B] = get<4>(problem_shape_in);
|
||||
auto [H_R, H_K] = H;
|
||||
|
||||
for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
|
||||
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
|
||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
||||
problem_shape_in,
|
||||
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
|
||||
problem_shape_in,
|
||||
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
|
||||
);
|
||||
// problem_shape = problem_shape_in;
|
||||
// offset = repeat_like(problem_shape_in, _0{});
|
||||
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
|
||||
auto mDK = domain_offset(select<1,2,4>(offset), mDK_in);
|
||||
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
|
||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
|
||||
ElementAccumulator acc_qk = 0;
|
||||
ElementAccumulator acc_dov = 0;
|
||||
ElementAccumulator acc_doo = 0;
|
||||
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
|
||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||
// acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||
// acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||
} // for idx_D0
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;
|
||||
|
||||
for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) {
|
||||
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
|
||||
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
|
||||
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
|
||||
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
|
||||
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
|
||||
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
|
||||
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
|
||||
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
|
||||
auto mDK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mDK_in);
|
||||
|
||||
for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
|
||||
ElementAcc acc_dk = 0;
|
||||
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
|
||||
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
|
||||
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
|
||||
ElementAcc acc_qk = 0;
|
||||
ElementAcc acc_dov = 0;
|
||||
ElementAcc acc_doo = 0;
|
||||
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
|
||||
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
|
||||
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
|
||||
acc_qk += rQ * rK;
|
||||
} // for idx_D0
|
||||
|
||||
for (int idx_D1 = 0; idx_D1 < D_VO; idx_D1++) {
|
||||
ElementAcc rDO = mDO(idx_Q, idx_D1, coord_HB);
|
||||
ElementAcc rV = mV(idx_K, idx_D1, coord_HB);
|
||||
ElementAcc rO = mO(idx_Q, idx_D1, coord_HB);
|
||||
acc_dov += rDO * rV;
|
||||
acc_doo += rDO * rO ;
|
||||
}
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
acc_qk = frag(0);
|
||||
|
||||
mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)) * softmax_scale * (acc_dov - acc_doo));
|
||||
} // for idx_Q
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int idx_D = threadIdx.x;
|
||||
if (idx_D < D) {
|
||||
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
|
||||
ElementAcc rQ = mQ(idx_Q, idx_D, coord_HB);
|
||||
ElementAcc rDS = mS[idx_Q];
|
||||
acc_dk += rDS * rQ;
|
||||
}
|
||||
}
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
acc_qk = frag(0);
|
||||
__syncthreads();
|
||||
} // for idx_H_R
|
||||
|
||||
mS[idx_Q] = static_cast<ElementAccumulator>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
|
||||
} // for idx_Q
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
|
||||
acc += mS[idx_Q] * ElementAccumulator(mQ(idx_Q, idx_D, idx_L));
|
||||
}
|
||||
mDK(idx_K, idx_D, idx_L) = static_cast<typename TensorDK::value_type>(acc);
|
||||
} // for idx_D
|
||||
int idx_D = threadIdx.x;
|
||||
if (idx_D < D) {
|
||||
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
|
||||
mDK(idx_K, idx_D, coord_HB) = static_cast<typename TensorDK::value_type>(acc_dk);
|
||||
}
|
||||
} // for idx_K
|
||||
} // for idx_L
|
||||
} // for idx_HB
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -216,54 +239,71 @@ void __global__ fmha_bwd_reference_dV_kernel(
|
||||
using ElementAcc = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ char mS_mem[];
|
||||
ElementAcc* mS = reinterpret_cast<ElementAcc*>(mS_mem);
|
||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||
|
||||
ElementAcc softmax_scale = 1.0 / sqrt(ElementAcc(size<2>(problem_shape_in)));
|
||||
ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) {
|
||||
auto [H, B] = get<4>(problem_shape_in);
|
||||
auto [H_R, H_K] = H;
|
||||
|
||||
for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) {
|
||||
auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B));
|
||||
auto [problem_shape, offset] = apply_variable_length_offset(
|
||||
problem_shape_in,
|
||||
make_coord(_0{}, _0{}, _0{}, _0{}, idx2crd(idx_L, get<4>(problem_shape_in)))
|
||||
problem_shape_in,
|
||||
make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B))
|
||||
);
|
||||
// problem_shape = problem_shape_in;
|
||||
// offset = repeat_like(problem_shape_in, _0{});
|
||||
auto mQ = domain_offset(select<0,2,4>(offset), mQ_in);
|
||||
auto mK = domain_offset(select<1,2,4>(offset), mK_in);
|
||||
auto mV = domain_offset(select<1,3,4>(offset), mV_in);
|
||||
auto mO = domain_offset(select<0,3,4>(offset), mO_in);
|
||||
auto mLSE = domain_offset(select<0,4>(offset), mLSE_in);
|
||||
auto mDO = domain_offset(select<0,3,4>(offset), mDO_in);
|
||||
auto mDV = domain_offset(select<1,3,4>(offset), mDV_in);
|
||||
for (int idx_K = blockIdx.x; idx_K < size<1>(problem_shape); idx_K += gridDim.x) {
|
||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(problem_shape); idx_Q += blockDim.x) {
|
||||
ElementAcc acc_qk = 0;
|
||||
auto [Q, K, D, D_VO, HB] = problem_shape;
|
||||
auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset;
|
||||
|
||||
for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) {
|
||||
ElementAcc rQ = mQ(idx_Q, idx_D0, idx_L);
|
||||
ElementAcc rK = mK(idx_K, idx_D0, idx_L);
|
||||
acc_qk += rQ * rK;
|
||||
} // for idx_D0
|
||||
auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in);
|
||||
auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in);
|
||||
auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in);
|
||||
auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in);
|
||||
auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in);
|
||||
auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in);
|
||||
auto mDV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mDV_in);
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
acc_qk = frag(0);
|
||||
for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) {
|
||||
ElementAcc acc_dv = 0;
|
||||
for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) {
|
||||
auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B);
|
||||
for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) {
|
||||
ElementAcc acc_qk = 0;
|
||||
|
||||
mS[idx_Q] = expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L));
|
||||
} // for idx_Q
|
||||
for (int idx_D0 = 0; idx_D0 < D; idx_D0++) {
|
||||
ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB);
|
||||
ElementAcc rK = mK(idx_K, idx_D0, coord_HB);
|
||||
acc_qk += rQ * rK;
|
||||
} // for idx_D0
|
||||
|
||||
__syncthreads();
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAcc>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
acc_qk = frag(0);
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<3>(problem_shape); idx_D += blockDim.x) {
|
||||
ElementAcc acc = 0;
|
||||
for (int idx_Q = 0; idx_Q < size<0>(problem_shape); idx_Q++) {
|
||||
ElementAcc rS = static_cast<Element>(mS[idx_Q]);
|
||||
ElementAcc rDO = mDO(idx_Q, idx_D, idx_L);
|
||||
acc += rS * rDO;
|
||||
}
|
||||
mDV(idx_K, idx_D, idx_L) = static_cast<typename TensorDV::value_type>(acc);
|
||||
} // for idx_D
|
||||
mS[idx_Q] = static_cast<Element>(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)));
|
||||
} // for idx_Q
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int idx_D_VO = threadIdx.x;
|
||||
if (idx_D_VO < D_VO) {
|
||||
for (int idx_Q = 0; idx_Q < Q; idx_Q++) {
|
||||
ElementAcc rDO = mDO(idx_Q, idx_D_VO, coord_HB);
|
||||
ElementAcc rP = mS[idx_Q];
|
||||
acc_dv += rP * rDO;
|
||||
}
|
||||
} // for idx_D
|
||||
|
||||
__syncthreads();
|
||||
} // for idx_H_R
|
||||
|
||||
int idx_D_VO = threadIdx.x;
|
||||
if (idx_D_VO < D_VO) {
|
||||
auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B);
|
||||
mDV(idx_K, idx_D_VO, coord_HB) = static_cast<typename TensorDV::value_type>(acc_dv);
|
||||
}
|
||||
} // for idx_K
|
||||
} // for idx_L
|
||||
}
|
||||
@ -288,7 +328,7 @@ void fmha_bwd_reference_dQ(
|
||||
|
||||
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mK) * sizeof(typename TensorLSE::value_type);
|
||||
int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type);
|
||||
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
|
||||
}
|
||||
|
||||
@ -310,9 +350,12 @@ void fmha_bwd_reference_dK(
|
||||
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
|
||||
auto [K, D, HB] = mDK.shape();
|
||||
auto [H, B] = HB;
|
||||
auto [H_R, H_K] = H;
|
||||
dim3 grid(K, H_K * B, 1);
|
||||
dim3 block(std::max(D, 256));
|
||||
int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type);
|
||||
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
|
||||
}
|
||||
|
||||
@ -334,9 +377,12 @@ void fmha_bwd_reference_dV(
|
||||
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mDO) * sizeof(typename TensorLSE::value_type);
|
||||
auto [K, D_VO, HB] = mDV.shape();
|
||||
auto [H, B] = HB;
|
||||
auto [H_R, H_K] = H;
|
||||
dim3 grid(K, H_K * B, 1);
|
||||
dim3 block(std::max(D_VO, 256));
|
||||
int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type);
|
||||
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
|
||||
}
|
||||
|
||||
|
||||
@ -117,7 +117,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized2Sm
|
||||
cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
|
||||
@ -88,7 +88,7 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -189,7 +189,7 @@ cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
@ -283,7 +283,7 @@ struct Result
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -489,19 +489,28 @@ int run(Options &options)
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
@ -509,8 +518,8 @@ int main(int argc, char const **args) {
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -530,9 +539,9 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -86,7 +86,7 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -217,7 +217,7 @@ cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_refer
|
||||
// Matrix-wide normalization constant
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
@ -311,7 +311,7 @@ struct Result
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -536,19 +536,28 @@ int run(Options &options)
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
@ -556,8 +565,8 @@ int main(int argc, char const **args) {
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -577,9 +586,9 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -88,7 +88,7 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -189,7 +189,7 @@ cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
@ -283,7 +283,7 @@ struct Result
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -489,19 +489,28 @@ int run(Options &options)
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
@ -509,8 +518,8 @@ int main(int argc, char const **args) {
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -530,9 +539,9 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -97,7 +97,7 @@ using namespace cute;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -263,7 +263,7 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
// NormConst is a single device-side constant value, its not per-batch or per-group
|
||||
cutlass::DeviceAllocation<ElementAccumulator> norm_constant_device;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
@ -466,7 +466,7 @@ struct Result
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -861,30 +861,39 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 ||
|
||||
((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)
|
||||
)
|
||||
) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer.\n";
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120a).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120 or 121).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -901,7 +910,7 @@ int main(int argc, char const **args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
|
||||
cutlass_example_add_executable(
|
||||
79a_blackwell_geforce_nvfp4_bf16_gemm
|
||||
79a_blackwell_geforce_nvfp4_bf16_gemm.cu
|
||||
|
||||
@ -78,7 +78,7 @@
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -248,7 +248,7 @@ struct Result
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
};
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -507,25 +507,34 @@ int run(Options &options)
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 120.
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
@ -540,9 +549,9 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
return 0;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -78,7 +78,7 @@
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -183,7 +183,7 @@ cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> bloc
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
cutlass::HostTensor<outputScaleFactor, cutlass::layout::PackedVectorLayout> block_reference_SFD;
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
@ -259,7 +259,7 @@ struct Result
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
};
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -531,25 +531,34 @@ int run(Options &options)
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 120.
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl;
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
@ -564,9 +573,9 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
return 0;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 120a)
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
|
||||
cutlass_example_add_executable(
|
||||
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm
|
||||
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu
|
||||
|
||||
104
examples/81_blackwell_gemm_blockwise/README.md
Normal file
@ -0,0 +1,104 @@
|
||||
# Blockwise and Groupwise GEMM and Grouped GEMM on Blackwell

Blockwise and Groupwise GEMM and Grouped GEMM implement software scaling in the accumulator type.
The examples in this directory demonstrate how to instantiate and run this kernel.
The profiler enables instantiating and profiling different kernel configurations for Blockwise and Groupwise GEMM
to determine the best performing kernel for your workload.

## Introduction
Blockwise and Groupwise GEMM operations enable fine-grained numerical precision control by applying scale factors at configurable granularities. This is particularly useful for quantized neural networks where different regions of tensors may have different scaling requirements.

For a GEMM $D = \alpha A B + \beta C$, we introduce two scale factor tensors, SFA
and SFB. This leads to a GEMM $D = \alpha (\text{SFA} * A)(\text{SFB} * B) + \beta C$. A minimal host-side reference of this computation is sketched below.

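To make the formula concrete, the following host-only C++ sketch (illustrative only; the helper name `blockwise_gemm_ref`, the row-major storage, and the evenly dividing sizes are assumptions, not part of the example code) applies the broadcast scale factors element by element:

```
#include <vector>

// Illustrative host reference for blockwise-scaled GEMM. gM, gN, gK are the
// scale granularities; all extents are assumed to divide evenly for brevity.
void blockwise_gemm_ref(int M, int N, int K, int gM, int gN, int gK,
                        float alpha, float beta,
                        std::vector<float> const& A,    // M x K, row-major
                        std::vector<float> const& B,    // K x N, row-major
                        std::vector<float> const& SFA,  // (M/gM) x (K/gK), row-major
                        std::vector<float> const& SFB,  // (N/gN) x (K/gK), row-major
                        std::vector<float> const& C,    // M x N, row-major
                        std::vector<float>&       D) {  // M x N, row-major
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        float sfa = SFA[(m / gM) * (K / gK) + (k / gK)];  // broadcast over a gM x gK block
        float sfb = SFB[(n / gN) * (K / gK) + (k / gK)];  // broadcast over a gN x gK block
        acc += (sfa * A[m * K + k]) * (sfb * B[k * N + n]);
      }
      D[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}
```
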
## Scale Factor Tensors
- *SFA*: Broadcast the same scale within a block defined by _scale granularity m_ and _scale granularity k_ when scaling A.
  - Scale granularity m and scale granularity k are also referred to as _scale vector m_ and _k_ respectively.
- *SFB*: Broadcast the same scale within a block defined by _scale granularity n_ and _scale granularity k_ when scaling B.
  - Scale granularity n and scale granularity k are also referred to as _scale vector n_ and _k_ respectively.

These can be represented in CuTe as:
- *SFA Layout*: $((\text{scale granularity M}, M / \text{scale granularity M}), (\text{scale granularity K}, K / \text{scale granularity K})) : ((0, int), (0, int))$
- *SFB Layout*: $((\text{scale granularity N}, N / \text{scale granularity N}), (\text{scale granularity K}, K / \text{scale granularity K})) : ((0, int), (0, int))$

The 0-element stride ensures that the same group of coordinates maps to the same element in the scale factors. A sketch of such a layout in CuTe is given below.

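To see the broadcast behaviour of the zero strides, here is a small host-only CuTe sketch (not taken from the examples; the sizes and the 128x128 granularity are made up) that builds an SFA-style layout and shows that every coordinate inside one scale block maps to the same scale-factor element:

```
#include <cstdio>
#include "cute/tensor.hpp"

int main() {
  using namespace cute;

  // Hypothetical sizes for illustration: M = 256, K = 512, 128x128 scale granularity.
  auto sf_m   = Int<128>{};   // scale granularity M
  auto sf_k   = Int<128>{};   // scale granularity K
  int  m_blks = 256 / 128;    // M / scale granularity M
  int  k_blks = 512 / 128;    // K / scale granularity K

  // ((sf_m, M/sf_m), (sf_k, K/sf_k)) : ((0, 1), (0, M/sf_m)) -- compact over the blocks.
  auto sfa_layout = make_layout(
      make_shape (make_shape (sf_m,     m_blks), make_shape (sf_k,     k_blks)),
      make_stride(make_stride(Int<0>{}, 1     ), make_stride(Int<0>{}, m_blks)));

  // Every (m, k) coordinate inside one 128x128 block maps to the same scale element.
  std::printf("%d\n", int(sfa_layout(make_coord(make_coord(0, 0), make_coord(  0, 0)))));  // 0
  std::printf("%d\n", int(sfa_layout(make_coord(make_coord(7, 0), make_coord(100, 0)))));  // still 0
  std::printf("%d\n", int(sfa_layout(make_coord(make_coord(0, 1), make_coord(  0, 0)))));  // 1: next M block
  return 0;
}
```
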
## Configuration

For convenience, the Blockwise and Groupwise implementations provide
`cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>`
to deduce layouts and manage compact tensors.

`cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>` by default makes
every tensor major in the M/N mode, but this can be configured. For example,
`cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, UMMA::Major::K, UMMA::Major::MN>`
denotes that SFA will be major in the K dimension while SFB will be major in the N dimension. A usage sketch follows below.

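A rough usage sketch follows. The header path and the member names (`deduce_layoutSFA`, `deduce_layoutSFB`, `tile_atom_to_shape_SFA`, `tile_atom_to_shape_SFB`) mirror the pattern used by the blockwise examples elsewhere in this change set; treat them as assumptions and verify them against your CUTLASS version:

```
// Sketch only: header location and member names may differ between CUTLASS versions.
#include "cutlass/detail/blockwise_scale_layout.hpp"
#include "cute/tensor.hpp"

using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128>;

// Scale-factor layout *types* are deduced at compile time...
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

// ...and tiled to a concrete (M, N, K, L) problem shape at run time.
inline void make_scale_layouts(int M, int N, int K, int L,
                               LayoutSFA& layout_SFA, LayoutSFB& layout_SFB) {
  layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
  layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
}
```
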
## Integration with Other Frameworks

If translating from frameworks like Torch, where SFA has shape
(M / ScaleGranularityM, K / ScaleGranularityK) and SFB has shape (K / ScaleGranularityK, N / ScaleGranularityN),
make sure to transpose SFB and B to fit the canonical CuTe layout form. This ensures K is always the second mode.
Strides can be used to determine whether each tensor is MN or K major in order to correctly form the layouts, either directly
or with the convenience wrappers. A sketch of such a reinterpretation is given below.

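As a sketch of the reinterpretation (the helper `transposed_sfb_view` is hypothetical and only illustrates the stride bookkeeping), a row-major SFB of shape (K / ScaleGranularityK, N / ScaleGranularityN) can be viewed as (N / ScaleGranularityN, K / ScaleGranularityK) without copying, by swapping the extents and keeping the element strides:

```
#include "cute/tensor.hpp"

// sfb_ptr points at a framework-resident, row-major tensor of shape
// (K/gK, N/gN), i.e. strides (N/gN, 1). The transposed view has shape
// (N/gN, K/gK) with strides (1, N/gN), so K becomes the second mode.
template <class Element>
auto transposed_sfb_view(Element const* sfb_ptr, int n_blks, int k_blks) {
  using namespace cute;
  auto layout = make_layout(make_shape (n_blks, k_blks),   // (N/gN, K/gK)
                            make_stride(1,      n_blks));  // stride 1 along the N blocks
  return make_tensor(make_gmem_ptr(sfb_ptr), layout);
}
```
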
## Kernel Selection and Profiling

To determine the most performant Blockwise/Groupwise GEMM or Grouped GEMM kernel for your use case, you can utilize the
[CUTLASS profiler](../../media/docs/cpp/profiler.md).

All Blockwise/Groupwise GEMMs and Grouped GEMMs with `f32` scaling of `e4m3` or runtime `f8` types can be selected by
choosing a subset of kernels when configuring with CMake by passing:
`-DCUTLASS_LIBRARY_KERNELS="cutlass3x*f32xe4m3_*f32xe4m3*,cutlass3x*f32xf8_*f32xf8*"`.

The simplest way to use the profiler is to pass `m`, `n`, and `k` as well as your `scale_vec_size_m`,
`scale_vec_size_n`, and `scale_vec_size_k`. Passing `enable-best-kernel-for-fixed-shape` will do some autotuning
per kernel to determine the best rasterization orders, swizzles, and cluster sizes. Passing `blockwiseGemm`
or `GroupedGemm` through the operation flag determines which set of operations will be profiled.

For example, this command will dump the performance of all compiled kernels which support scale
granularity m = 1, scale granularity n = 128, and scale granularity k = 128 for the problem size 8192x8192x8192:
```
cutlass_profiler --operation=blockwiseGemm \
  --enable-best-kernel-for-fixed-shape \
  --m=8192 --n=8192 --k=8192 \
  --scale_vec_size_m=1 --scale_vec_size_n=128 --scale_vec_size_k=128 \
  --verification-enabled=false
```

### Kernel Naming Convention

The naming of the blockwise and groupwise kernels includes the following new pattern: for each tensor/scale-factor pair we have
`<scale_granularity_m or scale_granularity_n>x<scale_granularity_k><accumulator type>x<scaled tensor type>`. For example,
`cutlass3x_sm100_tensorop_gemm_64x128f32xe4m3_1x128f32xe4m3_f32_f16_f16_64x128x128_1x1x1_0_nnn_align16_1sm` would denote:
- A CUTLASS 3 GEMM for SM100 that uses tensor cores.
- SFA is f32 with a 64 element scale granularity m and a 128 element scale granularity k.
- The A matrix is e4m3.
- SFB is f32 with a 1 element scale granularity n and a 128 element scale granularity k.
- The B matrix is e4m3.
- The epilogue is done in f32.
- The C matrix is f16.
- The D matrix is f16.
- The MMA tile shape is 64x128x128.
- The cluster shape is 1x1x1.
- A, B, C, and D are all column major.
- The alignment of the major modes is 16 elements for A, B, C, and D.
- The MMA variant is a 1SM instruction.

It is also worth noting that C can be void if scaling by beta is not needed.

## Performance Tips and Tricks

- *MMA Dimensions*: on both Blackwell and Hopper tensor cores, the smallest `MMA_M` dimension is 64, but the `MMA_N`
dimension can be as small as 8 for some instructions. For problem sizes where M is small, consider computing $D^T = \alpha B^T A^T + \beta C^T$ instead.
  - After swapping A and B and transposing, the N dimension becomes the small dimension. With a small `MMA_N` we can tile more effectively without performing unnecessary computation.
- *Layout Swapping*: When optimizing with the profiler, swap the `m` and `n` inputs and adjust the layouts to reflect this swapping and transposing.
  - For example, if we have a row-major A, column-major B, and row-major D, we can swap tensors and run a kernel with:
    - The left-hand matrix as row-major (since B transposed is row-major)
    - The right-hand matrix as column-major (since A transposed is column-major)
    - A column-major output (since D transposed is column-major).

When using blockwise and groupwise GEMM we must also swap the scale vector sizes when doing this optimization. If we have a 1 element scale granularity M
and a 128 element scale granularity N, we must run a kernel with a 128 element scale granularity M and a 1 element scale granularity
N. A small sketch of this swap is given below.

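The following sketch (illustrative only; the `BlockwiseProblem` struct and the `swap_and_transpose` helper are hypothetical) summarizes the bookkeeping from the bullets above: M and N trade places, the scale vector sizes swap with them, and every layout flips its majorness:

```
#include <utility>

// 'R' = row-major, 'C' = column-major.
struct BlockwiseProblem {
  int m, n, k;
  int scale_granularity_m, scale_granularity_n, scale_granularity_k;
  char major_a, major_b, major_d;
};

// Describe D^T = alpha * B^T A^T + beta * C^T in place of D = alpha * A B + beta * C.
BlockwiseProblem swap_and_transpose(BlockwiseProblem p) {
  BlockwiseProblem t = p;
  std::swap(t.m, t.n);                                      // M and N trade places
  std::swap(t.scale_granularity_m, t.scale_granularity_n);  // scale vector sizes swap too
  t.major_a = (p.major_b == 'R') ? 'C' : 'R';               // B^T becomes the left operand
  t.major_b = (p.major_a == 'R') ? 'C' : 'R';               // A^T becomes the right operand
  t.major_d = (p.major_d == 'R') ? 'C' : 'R';               // D^T flips majorness
  return t;
}
```
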
@ -26,7 +26,9 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
82_blackwell_distributed_gemm
|
||||
82_blackwell_distributed_gemm.cu
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -0,0 +1,497 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
#include "cutlass/detail/collective/mixed_input_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#include "mixed_dtype_helper.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
using MmaType = cutlass::bfloat16_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
using AccumulatorType = float;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using ElementZero = MmaType;
|
||||
using ElementScale = MmaType;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = AccumulatorType; // Element type for internal accumulation
|
||||
using ElementCompute = AccumulatorType; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using MmaTileShape = Shape<_256,_128,_128>; // (MmaTileShape_N, MmaTileShape_M, MmaTileShape_K) as A and B will be swapped
|
||||
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmMixedInputSm100; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
constexpr int ScaleGranularityN = 1; //Should be less than or equal to GEMM_N
|
||||
constexpr int ScaleGranularityK = 128; //Should be less than or equal to GEMM_K
|
||||
using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig<ScaleGranularityN, ScaleGranularityK>;
|
||||
using LayoutScale = decltype(ScaleConfig::deduce_layout_scale()); // Layout type for SFA matrix operand
|
||||
LayoutScale layout_S;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
// Transpose layout of D here since we use explicit swap + transpose
|
||||
// the void type for C tells the builder to allocate 0 smem for the C matrix.
|
||||
// We can enable this if beta == 0 by changing ElementC to void below.
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
|
||||
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
|
||||
>::CollectiveOp;
|
||||
|
||||
// ============================================================ MIXED INPUT NO SCALES ============================================================================
|
||||
//The collective will infer that the narrow type should be upcasted to the wide type.
|
||||
//We swap A and B operands to the builder here
|
||||
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB>, LayoutB_Transpose, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
MainloopSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopConvertOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale>, cute::tuple<LayoutB_Transpose, LayoutScale>, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
MainloopSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ==================================================================
|
||||
// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format.
|
||||
using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale, ElementZero>, cute::tuple<LayoutB_Transpose, LayoutScale>, AlignmentB,
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
MainloopSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleWithZeroPoint,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleWithZeroPoint>;
|
||||
// =================================================================================================================================================================
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
using StrideC = typename GemmKernelScaleOnly::StrideC;
|
||||
using StrideD = typename GemmKernelScaleOnly::StrideD;
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideC_ref stride_C_ref;
|
||||
StrideD stride_D;
|
||||
StrideD_ref stride_D_ref;
|
||||
uint64_t seed;
|
||||
|
||||
// Scale and Zero share a stride since the layout and shapes must be the same.
|
||||
using StrideS = typename cute::Stride<cute::Int<1>, int64_t, int64_t>;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
StrideS stride_S;
|
||||
StrideS_ref stride_S_ref;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<MmaType> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(MixedDtypeOptions const& options) {
|
||||
|
||||
auto shape_b = cute::make_shape(options.n, options.k, options.l);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
layout_S = ScaleConfig::tile_atom_to_shape_scale(make_shape(options.n, options.k, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_S)));
|
||||
|
||||
block_A.reset(a_coord.product());
|
||||
block_B.reset(b_coord.product());
|
||||
block_B_dq.reset(b_coord.product());
|
||||
block_C.reset(c_coord.product());
|
||||
block_D.reset(c_coord.product());
|
||||
block_ref_D.reset(c_coord.product());
|
||||
|
||||
block_scale.reset(blockscale_b_coord.product());
|
||||
block_zero.reset(blockscale_b_coord.product());
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale<QuantType, ElementScale>(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
if(options.verify){
|
||||
auto layout_B = make_layout(shape_b, stride_B);
|
||||
auto scale_stride = layout_S.stride();
|
||||
auto layout_scale_zero = make_layout(
|
||||
make_shape(size<0>(layout_S), size<1,1>(layout_S), size<2>(layout_S)),
|
||||
make_stride(size<0,1>(scale_stride), size<1,1>(scale_stride), size<2>(scale_stride))
|
||||
); //layout = (options.n, scale_k, options.l) : (_1, options.n, _0)
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, ScaleGranularityK, stream);
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <class Args, cutlass::detail::ConversionMode KernelConversionMode>
|
||||
Args args_from_options(MixedDtypeOptions const& options)
|
||||
{
|
||||
// Swap the A and B tensors, as well as problem shapes here.
|
||||
if constexpr (KernelConversionMode == cutlass::detail::ConversionMode::DirectConvert) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B.get(), stride_B, block_A.get(), stride_A},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
else if constexpr(KernelConversionMode == cutlass::detail::ConversionMode::ConvertAndScale) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), layout_S},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
else if constexpr(KernelConversionMode == cutlass::detail::ConversionMode::ConvertAndScaleWithZero) {
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), layout_S, block_zero.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
} else {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
bool verify(MixedDtypeOptions const& options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
constexpr int AlignmentBdq = 128 / cutlass::sizeof_bits<MmaType>::value;
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentBdq,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// compare_reference
|
||||
ElementD const epsilon(1e-2f);
|
||||
ElementD const non_zero_floor(1e-2f);
|
||||
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(MixedDtypeOptions &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments, Gemm::CollectiveMainloop::KernelConversionMode>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
if(options.verify){
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
else{
|
||||
result.passed = true;
|
||||
std::cout << " Verification: Off " << std::endl;
|
||||
}
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
mixed_dtype_profiling(gemm, options, result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100a.
|
||||
bool is_correct_cuda_version = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
|
||||
if (!is_correct_cuda_version) {
|
||||
std::cerr << "Version is " << __CUDACC_VER_MINOR__ << "\n";
|
||||
std::cerr << "This example requires CUDA 12.8 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture or "
|
||||
<< "later (compute capability 100a or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
MixedDtypeOptions options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
std::cout << "Running in conversion only mode." << std::endl;
|
||||
run<GemmConvertOnly>(options);
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
|
||||
std::cout << "Running in scale mode." << std::endl;
|
||||
run<GemmScaleOnly>(options);
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
std::cout << "Running in scale and zero mode." << std::endl;
|
||||
run<GemmScaleWithZeroPoint>(options);
|
||||
}
|
||||
else{
|
||||
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
45
examples/86_blackwell_mixed_dtype_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
set(TEST_S_TILE_SHAPE --m=256 --n=128 --k=32 --verify --iterations=0)
|
||||
set(TEST_S_TILE_SHAPE_MULTIPLE_KITER --m=256 --n=128 --k=128 --verify --iterations=0)
|
||||
set(TEST_S_DIFFERENT_MN --m=16384 --n=4608 --k=4608 --verify --iterations=0)
|
||||
set(TEST_S_ONE_WAVE --m=1536 --n=1536 --k=32 --verify --iterations=0) # Assuming 144 SMs
|
||||
set(TEST_S_2048 --m=2048 --n=2048 --k=2048 --verify --iterations=0) # Multi-wave
|
||||
|
||||
if(NOT WIN32)
|
||||
cutlass_example_add_executable(
|
||||
86_blackwell_mixed_dtype_gemm
|
||||
86_blackwell_mixed_dtype.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_S_TILE_SHAPE
|
||||
TEST_S_TILE_SHAPE_MULTIPLE_KITER
|
||||
TEST_S_ONE_WAVE
|
||||
TEST_S_2048
|
||||
)
|
||||
endif()
|
||||
269
examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh
Normal file
@ -0,0 +1,269 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <numeric>
|
||||
#include "helper.h"
|
||||
|
||||
enum MixedDtypeGemmMode {
|
||||
ConvertOnly,
|
||||
ScaleOnly,
|
||||
ScaleWithZeroPoint
|
||||
};
|
||||
|
||||
/// Command line options parsing
|
||||
struct MixedDtypeOptions {
|
||||
|
||||
bool help = false;
|
||||
bool verify = false;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int iterations = 1000;
|
||||
int warmup = 1000;
|
||||
int mode = 1;
|
||||
int m = 5120, n = 4096, k = 4096;
|
||||
int l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("verify")) {
|
||||
verify = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("mode", mode);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("warmup", warmup);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "86_blackwell_mixed_dtype_gemm\n\n"
|
||||
<< " Blackwell Mixed Data Type GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
||||
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --warmup=<int> Number of warmup iterations to perform.\n\n"
|
||||
<< " --verify=<int> Run verification.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "86_blackwell_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct MixedDtypeResult
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
|
||||
};
|
||||
|
||||
/// Profiling Loop
|
||||
template <class Gemm>
|
||||
void mixed_dtype_profiling(
|
||||
Gemm& gemm,
|
||||
MixedDtypeOptions const& options,
|
||||
MixedDtypeResult& result) {
|
||||
|
||||
if (options.iterations <= 0) return;
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
std::vector<float> runtimes;
|
||||
runtimes.reserve(options.iterations);
|
||||
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
}
|
||||
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
|
||||
// Compute average setup and runtime and GFLOPs.
|
||||
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
|
||||
}
|
||||
|
||||
/// Helpers to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
bool initialize_quant_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
|
||||
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class QuantType, class Element>
|
||||
bool initialize_scale(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
MixedDtypeOptions const& options,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
|
||||
std::vector<Element> stage(block.size(), Element(1.0f));
|
||||
block.copy_from_host(stage.data());
|
||||
}
|
||||
else {
|
||||
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
|
||||
const float max_dequant_val = 4.f;
|
||||
const float min_dequant_val = 0.5f;
|
||||
|
||||
float scope_max(max_dequant_val / elt_max_f);
|
||||
float scope_min(min_dequant_val / elt_max_f);
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_zero(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
MixedDtypeOptions const& options,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
|
||||
} else {
|
||||
// No zero-point, so just initialize with 0 so we can use the same kernel to dequantize the data.
|
||||
std::vector<Element> stage(block.size(), Element(0.0f));
|
||||
block.copy_from_host(stage.data());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -0,0 +1,518 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief An FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockwise scaling FP8 GEMM on the NVIDIA Blackwell SM120 architecture.
|
||||
This kernel is optimized for the GeForce RTX 50 series GPUs.
|
||||
|
||||
This kernel accepts Inputs A and B with TileMxTileK and TileNxTileK FP32 block scaling, performing scaling and accumulation every TileK elements.
|
||||
Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:
|
||||
|
||||
1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper.
|
||||
2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
3. Epilogue Optimization
|
||||
|
||||
Note that GeForce RTX 50 series GPUs do not support:
|
||||
1. Multicast feature of TMA load. Cluster shape has to be 1x1x1.
|
||||
2. Dynamic datatypes.
|
||||
3. Runtime scaling block size.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/87_blackwell_geforce_gemm_blockwise/87a_blackwell_geforce_fp8_bf16_gemm_blockwise --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#include "./utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
|
||||
using ElementCompute = float;
|
||||
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile
|
||||
using MmaTileShape_MNK = Shape<_128,_128,_128>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
using ScaleConfig = decltype(cutlass::detail::sm120_trivial_blockwise_scale_config(MmaTileShape_MNK{}));
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutC, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
// Strides just iterate over scalars and have no zeros
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
// Layouts are tiled to the problem size and the strides have zeros
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
|
||||
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
|
||||
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
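// Editorial note (illustrative sketch, not part of the original example): with the trivial
// SM120 blockwise scale config above, whose granularity we take to be the 128x128x128 MMA tile,
// one FP32 scale factor covers one tile of A and one tile of B, so the scale-factor tensors
// sized via filter_zeros() in initialize() simply hold one entry per tile. The helper names
// below are ours, not CUTLASS API.
constexpr int blockwise_tiles(int extent, int tile = 128) { return (extent + tile - 1) / tile; }
constexpr int blockwise_scale_count_A(int m, int k, int l) { return blockwise_tiles(m) * blockwise_tiles(k) * l; }
constexpr int blockwise_scale_count_B(int n, int k, int l) { return blockwise_tiles(n) * blockwise_tiles(k) * l; }
// Example: m = n = k = 2048, l = 1  ->  16 M-tiles x 16 K-tiles = 256 scales for A, and likewise 256 for B.
static_assert(blockwise_scale_count_A(2048, 2048, 1) == 256, "16 x 16 tiles");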
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool skip_verification = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("skip-verification")) {
|
||||
skip_verification = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "87a_blackwell_geforce_gemm_blockwise\n\n"
|
||||
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --skip-verification Skip verification.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "87a_blackwell_geforce_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
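// Editorial note: a quick sanity check of the flop accounting in Options::gflops() above
// (plain C++ illustration, not part of the original example). For the documented problem size
// m = n = k = 2048, the GEMM performs 2 * 2048^3 = 17,179,869,184 flops, i.e. about 17.18 GFLOP,
// so an average runtime of 0.5 ms would report roughly 34,360 GFLOP/s.
static_assert(uint64_t(2) * 2048 * 2048 * 2048 == 17179869184ull, "two flops per multiply-add");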
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
tensor_SFA.resize(blockscale_a_coord);
|
||||
tensor_SFB.resize(blockscale_b_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
|
||||
|
||||
initialize_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
|
||||
initialize_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
|
||||
tensor_SFA.sync_device();
|
||||
tensor_SFB.sync_device();
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A,
|
||||
tensor_B.device_data(), stride_B,
|
||||
tensor_SFA.device_data(), layout_SFA,
|
||||
tensor_SFB.device_data(), layout_SFB},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
|
||||
return arguments;
|
||||
}
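// Editorial note (scalar sketch, ours, not the CUTLASS implementation): the linear-combination
// epilogue driven by fusion_args above computes, per output element,
//   D = alpha * accumulator + beta * C
// where the accumulator already holds the blockwise-scaled FP8 dot product.
inline float epilogue_linear_combination_ref(float accumulator, float c, float alpha, float beta) {
  return alpha * accumulator + beta * c;
}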
|
||||
|
||||
bool verify(Options const& options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
|
||||
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
return passed;
|
||||
}
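// Editorial note: TensorEquals() above is a bit-exact comparison, which is viable here because the
// host reference presumably follows the same blockwise accumulation order as the kernel. If a
// tolerance-based check were ever preferred, a dependency-free helper could look like the
// hypothetical sketch below (not used anywhere in this example).
inline bool nearly_equal(float a, float b, float rel_tol = 1e-3f) {
  float diff  = (a > b) ? (a - b) : (b - a);
  float abs_a = (a < 0.f) ? -a : a;
  float abs_b = (b < 0.f) ? -b : b;
  float mag   = (abs_a > abs_b) ? abs_a : abs_b;
  return diff <= rel_tol * mag || diff < 1e-6f;
}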
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <class Gemm>
|
||||
int run(Options &options) {
|
||||
initialize(options);
|
||||
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
Result result;
|
||||
if (!options.skip_verification) {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Run
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,539 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
    \brief An FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS.

    This example demonstrates how to instantiate and run cooperative and ping-pong groupwise-scaling FP8 GEMMs
    on the NVIDIA Blackwell SM120 architecture. These kernels are optimized for GeForce RTX 50 series GPUs.

    The block-scaling kernels accept inputs A and B with 1xTileK and TileNxTileK FP32 block scaling respectively,
    performing scaling and accumulation every TileK elements along K.
    The ping-pong kernel uses a smaller tile shape to avoid register spilling and thereby improve performance.
    Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages:

    1. The warp-specialized persistent kernel design that supports both the cooperative and ping-pong kernel
       schedules introduced in Hopper.
    2. The new software-controlled dynamic scheduler based on cluster launch control
       (see https://docs.nvidia.com/cuda/parallel-thread-execution).
    3. Epilogue optimization.

    Note that GeForce RTX 50 series GPUs do not support:
    1. The multicast feature of TMA load; the cluster shape has to be 1x1x1.
    2. Dynamic datatypes.
    3. Runtime scaling block sizes.

    Usage:

      $ ./examples/87_blackwell_geforce_gemm_blockwise/87b_blackwell_geforce_fp8_bf16_gemm_groupwise --m=2048 --n=2048 --k=2048
*/
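// Editorial note: a minimal host-side sketch of the groupwise-scaled accumulation described in the
// comment above (plain C++, float stands in for FP8; the indexing and packing below are ours, not
// the kernel's). A carries one scale per (row, 128-wide K block); B carries one scale per
// (128x128 block); partial sums are scaled and accumulated in FP32 once per TileK slice of K.
constexpr int kRefTileK = 128;  // K scaling granularity assumed by this sketch
constexpr int kRefTileN = 128;  // N scaling granularity assumed by this sketch

inline void groupwise_scaled_gemm_ref(
    int M, int N, int K,
    float const* A,    // M x K, contiguous along K: A[m*K + k]
    float const* B,    // K x N, contiguous along K: B[n*K + k]
    float const* SFA,  // M x ceil(K/kRefTileK) scales:                    SFA[m*num_kb + kb]
    float const* SFB,  // ceil(N/kRefTileN) x ceil(K/kRefTileK) scales:    SFB[(n/kRefTileN)*num_kb + kb]
    float const* C, float* D, float alpha, float beta) {
  int num_kb = (K + kRefTileK - 1) / kRefTileK;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int kb = 0; kb < num_kb; ++kb) {
        float partial = 0.f;
        int k_end = ((kb + 1) * kRefTileK < K) ? (kb + 1) * kRefTileK : K;
        for (int k = kb * kRefTileK; k < k_end; ++k) {
          partial += A[m * K + k] * B[n * K + k];
        }
        // apply the per-block scale factors once per K block, then accumulate in FP32
        acc += SFA[m * num_kb + kb] * SFB[(n / kRefTileN) * num_kb + kb] * partial;
      }
      D[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}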
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#include "./utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
|
||||
using ElementCompute = float;
|
||||
|
||||
|
||||
// MMA and Cluster Tile Shapes
// Shape of the tile
using CooperativeMmaTileShape_MNK = Shape<_128,_128,_128>;
// Smaller tile size for the ping-pong schedule to avoid register spilling
using PingpongMmaTileShape_MNK = Shape<_64, _128, _128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;

constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;

using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
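// Editorial note (our arithmetic, hedged): with ScaleGranularityM/N/K = 1/128/128 above, SFA holds
// one scale per (row of A, 128-wide K block) and SFB one scale per (128x128 block of B). For
// m = n = k = 2048 that is 2048 * 16 = 32768 scales for A and 16 * 16 = 256 for B, which should
// match what size(filter_zeros(layout_SFA/B)) reports in initialize(). Helper names are ours.
constexpr int groupwise_scale_count_A(int m, int k) {
  return m * ((k + ScaleGranularityK - 1) / ScaleGranularityK);
}
constexpr int groupwise_scale_count_B(int n, int k) {
  return ((n + ScaleGranularityN - 1) / ScaleGranularityN) * ((k + ScaleGranularityK - 1) / ScaleGranularityK);
}
static_assert(groupwise_scale_count_A(2048, 2048) == 32768, "one scale per row per 128-wide K block");
static_assert(groupwise_scale_count_B(2048, 2048) == 256,  "one scale per 128x128 block of B");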
|
||||
|
||||
|
||||
template <class TileShape>
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutC, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
template <class TileShape, class Schedule>
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue<TileShape>::SharedStorage))>,
|
||||
Schedule // cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120
|
||||
>::CollectiveOp;
|
||||
|
||||
template <class TileShape, class Schedule>
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop<TileShape, Schedule>,
|
||||
CollectiveEpilogue<TileShape>,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
// We are using cooperative kernel schedule by default
|
||||
using CooperativeGemm = cutlass::gemm::device::GemmUniversalAdapter<
|
||||
GemmKernel<CooperativeMmaTileShape_MNK, cutlass::gemm::KernelScheduleSm120Blockwise>>;
|
||||
|
||||
// Pingpong kernel
|
||||
using PingpongGemm = cutlass::gemm::device::GemmUniversalAdapter<
|
||||
GemmKernel<PingpongMmaTileShape_MNK, cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120>>;
|
||||
|
||||
using StrideA = typename CooperativeGemm::GemmKernel::StrideA;
|
||||
using StrideB = typename CooperativeGemm::GemmKernel::StrideB;
|
||||
using StrideC = typename CooperativeGemm::GemmKernel::StrideC;
|
||||
using StrideD = typename CooperativeGemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
// Strides just iterate over scalars and have no zeros
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
// Layouts are tiled to the problem size and the strides have zeros
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
|
||||
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
|
||||
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
|
||||
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool skip_verification = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("skip-verification")) {
|
||||
skip_verification = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "87b_blackwell_geforce_gemm_groupwise\n\n"
|
||||
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --skip-verification Skip verification.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "87b_blackwell_geforce_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
tensor_SFA.resize(blockscale_a_coord);
|
||||
tensor_SFB.resize(blockscale_b_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
|
||||
|
||||
initialize_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
|
||||
initialize_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
|
||||
tensor_SFA.sync_device();
|
||||
tensor_SFB.sync_device();
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <class Gemm>
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A,
|
||||
tensor_B.device_data(), stride_B,
|
||||
tensor_SFA.device_data(), layout_SFA,
|
||||
tensor_SFB.device_data(), layout_SFB},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
|
||||
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <class Gemm>
|
||||
int run(Options &options) {
|
||||
initialize(options);
|
||||
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
Result result;
|
||||
if (!options.skip_verification) {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Run
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
printf("Running kernel with Cooperative kernel schedule:\n");
|
||||
run<CooperativeGemm>(options);
|
||||
|
||||
printf("Running kernel with Pingpong kernel schedule:\n");
|
||||
run<PingpongGemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,678 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
    \brief An FP8 groupwise scaled grouped GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS.

    This example demonstrates an implementation of grouped GEMM using a TMA + Blackwell SM120 TensorOp-based
    warp-specialized kernel for FP8 with per-group 1x128x128 FP32 scaling factors.
    In this example, M, N, and K are fixed across groups.
    Because RTX 50 series GPUs do not support runtime scaling block sizes, all groups share the same block scaling size.
    All scheduling work is performed on the device, using device-side modification of TMA descriptors
    to move between groups (problem_count, represented by --groups):
    https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device

    To run this example:

      $ ./examples/87_blackwell_geforce_gemm_blockwise/87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise --m=2048 --n=2048 --k=2048 --groups=10

    The command above sizes all 10 groups at the given m, n, and k values. The same applies to the alpha and
    beta values, which may be randomized across the different groups.
*/
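// Editorial note (hypothetical host-side helpers, not part of the example's API): the comment above
// says every group is sized at the same (m, n, k); the sketch below mirrors that setup and totals
// the work across groups the same way Options::gflops() does further down.
#include <tuple>
#include <vector>

inline std::vector<std::tuple<int,int,int>> make_uniform_groups(int m, int n, int k, int groups) {
  return std::vector<std::tuple<int,int,int>>(groups, std::make_tuple(m, n, k));
}

inline double total_gflop(std::vector<std::tuple<int,int,int>> const& problems) {
  double flop = 0.0;
  for (auto const& p : problems) {
    flop += 2.0 * std::get<0>(p) * std::get<1>(p) * std::get<2>(p);  // two flops per multiply-add
  }
  return flop / 1.0e9;
}
// Example: make_uniform_groups(2048, 2048, 2048, 10) describes 10 identical problems,
// and total_gflop(...) evaluates to roughly 171.8 GFLOP in total.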
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#include "./utils.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
|
||||
using ElementCompute = float;
|
||||
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile
|
||||
using MmaTileShape_MNK = Shape<_128,_128,_128>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_1,_1,_1>;
|
||||
|
||||
// Scaling Factors
|
||||
using ElementSF = ElementAccumulator;
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 128;
|
||||
constexpr int ScaleGranularityK = 128;
|
||||
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutD *, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelScheduleSm120Blockwise
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA, LayoutSFA>);
|
||||
static_assert(cute::is_same_v<typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB, LayoutSFB>);
|
||||
|
||||
|
||||
/// Initialization
|
||||
uint64_t seed;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFB> layout_SFB_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
using HostTensorA = cutlass::HostTensor<ElementA, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorB = cutlass::HostTensor<ElementB, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorC = cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorD = cutlass::HostTensor<Gemm::EpilogueOutputOp::ElementOutput, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorSFA = cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorSFB = cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout>;
|
||||
|
||||
std::vector<HostTensorA> block_A;
|
||||
std::vector<HostTensorB> block_B;
|
||||
std::vector<HostTensorC> block_C;
|
||||
std::vector<HostTensorD> block_D;
|
||||
std::vector<HostTensorD> block_ref_D;
|
||||
std::vector<HostTensorSFA> block_SFA;
|
||||
std::vector<HostTensorSFB> block_SFB;
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
cutlass::DeviceAllocation<ElementA const*> ptr_A;
|
||||
cutlass::DeviceAllocation<ElementB const*> ptr_B;
|
||||
cutlass::DeviceAllocation<ElementSF const*> ptr_SFA;
|
||||
cutlass::DeviceAllocation<ElementSF const*> ptr_SFB;
|
||||
cutlass::DeviceAllocation<ElementC const*> ptr_C;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_D;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
|
||||
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool skip_verification = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1, groups = 10;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::AlongN;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("skip-verification")) {
|
||||
skip_verification = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char, 'N');
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster_order = RasterOrderOptions::AlongN;
|
||||
} else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster_order = RasterOrderOptions::AlongM;
|
||||
}
|
||||
|
||||
for (int i = 0; i < groups; ++i) {
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "87c_blackwell_geforce_grouped_gemm_groupwise\n\n"
|
||||
<< " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --skip-verification Skip verification.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "87c_blackwell_geforce_grouped_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --groups=8 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * groups;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementSF *> ptr_SFA_host(options.groups);
|
||||
std::vector<ElementSF *> ptr_SFB_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementD *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
for (int i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto [M, N, K] = problem;
|
||||
|
||||
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
|
||||
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
|
||||
auto stride_D = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
|
||||
|
||||
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
|
||||
auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
|
||||
auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
|
||||
|
||||
stride_A_host.push_back(stride_A);
|
||||
stride_B_host.push_back(stride_B);
|
||||
layout_SFA_host.push_back(layout_SFA);
|
||||
layout_SFB_host.push_back(layout_SFB);
|
||||
stride_C_host.push_back(stride_C);
|
||||
stride_D_host.push_back(stride_D);
|
||||
|
||||
block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A))));
|
||||
block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B))));
|
||||
block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C))));
|
||||
block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
|
||||
block_SFA.push_back(HostTensorSFA(cutlass::make_Coord(size(filter_zeros(layout_SFA)))));
|
||||
block_SFB.push_back(HostTensorSFB(cutlass::make_Coord(size(filter_zeros(layout_SFB)))));
|
||||
block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.groups; ++i) {
|
||||
initialize_tensor(block_A.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2022);
|
||||
initialize_tensor(block_B.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2023);
|
||||
initialize_tensor(block_C.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2024);
|
||||
initialize_tensor(block_SFA.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2025);
|
||||
initialize_tensor(block_SFB.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2026);
|
||||
|
||||
block_A.at(i).sync_device();
|
||||
block_B.at(i).sync_device();
|
||||
block_C.at(i).sync_device();
|
||||
block_SFA.at(i).sync_device();
|
||||
block_SFB.at(i).sync_device();
|
||||
|
||||
ptr_A_host.at(i) = block_A.at(i).device_data();
|
||||
ptr_B_host.at(i) = block_B.at(i).device_data();
|
||||
ptr_C_host.at(i) = block_C.at(i).device_data();
|
||||
ptr_D_host.at(i) = block_D.at(i).device_data();
|
||||
ptr_SFA_host.at(i) = block_SFA.at(i).device_data();
|
||||
ptr_SFB_host.at(i) = block_SFB.at(i).device_data();
|
||||
|
||||
alpha_host.push_back((options.alpha == std::numeric_limits<float>::max()) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == std::numeric_limits<float>::max()) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_SFA.reset(options.groups);
|
||||
ptr_SFA.copy_from_host(ptr_SFA_host.data());
|
||||
|
||||
ptr_SFB.reset(options.groups);
|
||||
ptr_SFB.copy_from_host(ptr_SFB_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
layout_SFA.reset(options.groups);
|
||||
layout_SFA.copy_from_host(layout_SFA_host.data());
|
||||
|
||||
layout_SFB.reset(options.groups);
|
||||
layout_SFB.copy_from_host(layout_SFB_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
}
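// Hedged illustration (plain CUDA rather than the cutlass::DeviceAllocation
// wrapper the example uses): the reset()/copy_from_host() calls above implement
// this pattern -- a device-resident array of per-group device pointers that the
// grouped GEMM kernel dereferences to locate each group's operands.
inline float** example_device_pointer_array(std::vector<float*> const& host_ptrs) {
  float** device_array = nullptr;
  cudaMalloc(&device_array, host_ptrs.size() * sizeof(float*));
  cudaMemcpy(device_array, host_ptrs.data(),
             host_ptrs.size() * sizeof(float*), cudaMemcpyHostToDevice);
  return device_array;
}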
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = options.raster_order;
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
|
||||
if (options.alpha != std::numeric_limits<float>::max()) {
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 0};
|
||||
} else {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||
}
|
||||
|
||||
if (options.beta != std::numeric_limits<float>::max()) {
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
fusion_args.dBeta = {_0{}, _0{}, 0};
|
||||
} else {
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
fusion_args.dBeta = {_0{}, _0{}, 1};
|
||||
}
|
||||
|
||||
arguments = {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(),
|
||||
ptr_B.get(), stride_B.get(),
|
||||
ptr_SFA.get(), layout_SFA.get(),
|
||||
ptr_SFB.get(), layout_SFB.get()},
|
||||
{
|
||||
fusion_args,
|
||||
ptr_C.get(), stride_C.get(),
|
||||
ptr_D.get(), stride_D.get()
|
||||
},
|
||||
hw_info, scheduler
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
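// Hedged illustration (not CUTLASS API): when per-group scalars are in use, the
// epilogue reads group g's alpha as alpha_ptr_array[g * batch_stride], which is
// what dAlpha = {_0{}, _0{}, 1} expresses above; with a single command-line
// alpha the pointer array is nullptr and the by-value fusion_args.alpha is
// broadcast (batch stride 0). A plain-C++ model of the indexed case:
inline float example_group_alpha(float const* alpha_ptr_array, int group_idx) {
  constexpr int batch_stride = 1;                   // one scalar per group
  return alpha_ptr_array[group_idx * batch_stride];
}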
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
bool passed = true;
|
||||
|
||||
for (int i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto [M, N, K] = problem;
|
||||
|
||||
auto A = cute::make_tensor(block_A.at(i).host_data(),
|
||||
cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i)));
|
||||
auto B = cute::make_tensor(block_B.at(i).host_data(),
|
||||
cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i)));
|
||||
auto C = cute::make_tensor(block_C.at(i).host_data(),
|
||||
cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i)));
|
||||
auto D = cute::make_tensor(block_ref_D.at(i).host_data(),
|
||||
cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i)));
|
||||
auto SFA = cute::make_tensor(block_SFA.at(i).host_data(), layout_SFA_host.at(i));
|
||||
auto SFB = cute::make_tensor(block_SFB.at(i).host_data(), layout_SFB_host.at(i));
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = alpha_host.at(i);
|
||||
epilogue_params.beta = beta_host.at(i);
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
|
||||
block_D.at(i).sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view());
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <class Gemm>
|
||||
int run(Options &options) {
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
Result result;
|
||||
if (!options.skip_verification) {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << " " << options.groups << " Groups" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support,
|
||||
// or CUDA 12.9 or higher for SM121 support.
|
||||
// Must have compute capability at least 120.
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) {
|
||||
std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Run
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
initialize(options);
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
47
examples/87_blackwell_geforce_gemm_blockwise/CMakeLists.txt
Normal file
@ -0,0 +1,47 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
|
||||
cutlass_example_add_executable(
|
||||
87a_blackwell_geforce_fp8_bf16_gemm_blockwise
|
||||
87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
87b_blackwell_geforce_fp8_bf16_gemm_groupwise
|
||||
87b_blackwell_geforce_fp8_bf16_gemm_groupwise.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise
|
||||
87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
83
examples/87_blackwell_geforce_gemm_blockwise/utils.h
Normal file
@ -0,0 +1,83 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element, class Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_input == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
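// Usage sketch (illustration only; assumes this header is included after the
// CUTLASS utility headers the examples already pull in, e.g.
// cutlass/util/host_tensor.h and cutlass/util/reference/host/tensor_fill.h).
inline void example_fill_and_upload() {
  // Fill a 128x256 row-major fp32 host tensor with uniform random values in
  // the range initialize_tensor() selects for 32-bit types, then copy it to
  // the device.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor(cutlass::MatrixCoord(128, 256));
  initialize_tensor(tensor.host_view(), cutlass::Distribution::Uniform, /*seed=*/2024);
  tensor.sync_device();
}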
|
||||
@ -137,6 +137,9 @@ struct FmhaKernelTma {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
|
||||
#else
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
// Shared memory.
|
||||
@ -216,6 +219,7 @@ struct FmhaKernelTma {
|
||||
result, typename CollectiveMainloop::TiledMmaPV{},
|
||||
params.problem_size, params.epilogue,
|
||||
epi_load_pipeline, storage.epilogue);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -160,6 +160,9 @@ struct FmhaKernelTmaWarpSpecialized {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
enum class WarpGroupRole {
|
||||
Producer = 0,
|
||||
@ -412,6 +415,7 @@ struct FmhaKernelTmaWarpSpecialized {
|
||||
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.arrive(); }
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
545
examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm.cu
Normal file
@ -0,0 +1,545 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM103 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockscaled 3xFP4 GEMM on the NVIDIA Blackwell SM103 architecture.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e2m1_t; // Element type for A matrix operand
|
||||
using ElementSFA = cutlass::float_ue4m3_t;
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e2m1_t; // Element type for B matrix operand
|
||||
using ElementSFB = cutlass::float_ue4m3_t;
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm103; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = cute::Shape<cute::_128, cute::_128, Int<768>>; // MMA's tile size
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_4, cute::_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
// Epilogue fusion operator
|
||||
using EpilogueFusionOp = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized1Sm,
|
||||
EpilogueFusionOp
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementA,ElementSFA>, LayoutATag, AlignmentA,
|
||||
cute::tuple<ElementB,ElementSFB>, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 // Kernel schedule policy. Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementSFA, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementSFB, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "89_sm103_fp4_ultra_gemm\n\n"
|
||||
<< " Sm103 3xFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else {
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
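// Hedged sketch (hypothetical helper, not part of the example; assumes
// <utility> is available): the same range policy as above in plain C++, keyed
// on element bit width. Narrow formats get tighter ranges so random values
// stay representable, and ue8m0 scale factors stay strictly positive.
inline std::pair<double, double> example_fill_range(int bits, bool is_ue8m0) {
  if (bits == 1) { return {0.0, 2.0}; }
  if (bits <= 6) { return {-2.0, 2.0}; }
  if (bits <= 8) {
    return is_ue8m0 ? std::pair<double, double>{1.0, 4.0}
                    : std::pair<double, double>{-1.0, 1.0};
  }
  return {-4.0, 4.0};
}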
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
uint8_t* workspace = nullptr;
|
||||
cudaError_t status = cudaMalloc(&workspace, workspace_size);
|
||||
if (status != cudaSuccess) {
|
||||
std::cerr << "Failed to allocate workspace memory: " << cudaGetErrorString(status) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}

// Free workspace memory once profiling is complete
cudaFree(workspace);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.9 or higher Toolkit to run this example
|
||||
// and must run on a GPU with compute capability 103.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) {
|
||||
std::cerr << "This example requires CUDA 12.9 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 3)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 103)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
38
examples/89_sm103_fp4_ultra_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 103a)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
89_sm103_fp4_ultra_gemm
|
||||
89_sm103_fp4_ultra_gemm.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
File diff suppressed because it is too large
66
examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --beta=2.0 --k=512 --groups=51 --iterations=0)
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --beta=0.5 --groups=50 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
|
||||
set(TEST_RANDOM_SMALL_GROUP --groups=3 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_SMALL_GROUP --alpha=1.5 --beta=2.0 --groups=3 --iterations=1) # Random problem sizes
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 103a)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
90_sm103_fp4_ultra_grouped_gemm
|
||||
90_sm103_fp4_ultra_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_SMALL_GROUP
|
||||
TEST_EPILOGUE_SMALL_GROUP
|
||||
)
|
||||
|
||||
endif()
|
||||
898
examples/91_fp4_gemv/91_fp4_gemv.cu
Normal file
@ -0,0 +1,898 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <cstdint> // uint64_t
|
||||
#include <cstdio>
|
||||
#include <cstdlib> // rand(), RAND_MAX
|
||||
#include <string> // std::stoi
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <float.h>
|
||||
#include <optional>
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp" // FIX cute header file inclusion issue
|
||||
// clang-format on
|
||||
|
||||
#include "cute/arch/mma_sm100_desc.hpp" // cute::UMMA::Major
|
||||
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
|
||||
#include "cutlass/complex.h" // cutlass::ComplexTransform
|
||||
#include "cutlass/cutlass.h" // cutlass::Status
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp" // cutlass::detail::Sm1xxBlockScaledOutputConfig
|
||||
#include "cutlass/epilogue/thread/linear_combination.h" // cutlass::epilogue::thread::LinearCombination
|
||||
#include "cutlass/gemm/device/gemv_blockscaled.h" // cutlass::gemm::device::Gemv
|
||||
#include "cutlass/gemm/kernel/gemv_blockscaled.h" // cutlass::gemm::kernel::Gemv
|
||||
#include "cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h" // cutlass::epilogue::threadblock::GemvEpilogueWithScalingFactor
|
||||
#include "cutlass/gemm_coord.h" // cutlass::GemmCoord
|
||||
#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory
|
||||
#include "cutlass/numeric_size.h" // cutlss::is_subbyte
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/platform/platform.h" // cutlass::is_same_v
|
||||
#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation
|
||||
#include "cutlass/util/distribution.h" // cutlass::Distribution
|
||||
#include "cutlass/util/host_tensor.h" // cutlass::HostTensor
|
||||
#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride
|
||||
#include "cutlass/util/reference/host/gemm_complex.h" // cutlass::reference::host::GemmComplex
|
||||
#include <cutlass/util/reference/host/gett.hpp> // cutlass::reference::host::GettBlockScalingMainloopParams
|
||||
// cutlass::reference::host::GettBlockScalingEpilogueParams
|
||||
// cutlass::reference::host::Gemm3x
|
||||
#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals
|
||||
#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform
|
||||
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
|
||||
|
||||
// Helper Functions
|
||||
template <typename T>
|
||||
auto
|
||||
make_iterator(T* ptr)
|
||||
{
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
template <typename Element, typename Layout>
|
||||
bool
|
||||
initialize_tensor(cutlass::TensorView<Element, Layout> view, cutlass::Distribution::Kind dist_kind, uint64_t seed)
|
||||
{
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_input <= 8) {
|
||||
if constexpr (cutlass::is_same_v<Element, cutlass::float_ue4m3_t> ||
|
||||
cutlass::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
} else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
} else {
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
|
||||
else {
|
||||
CUTLASS_ASSERT(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Base class of Testbed
|
||||
template <
|
||||
typename Gemv_,
|
||||
// The following types are more difficult to be derived from EVT
|
||||
typename ElementC, typename LayoutC, typename ElementD_,
|
||||
typename LayoutD, typename ElementSFD_, typename LayoutSFD,
|
||||
typename ElementCompute_, int kVectorSize_>
|
||||
struct TestbedGemvFp4SFDBase
|
||||
{
|
||||
public:
|
||||
using Gemv = Gemv_;
|
||||
|
||||
using ElementA = typename Gemv::ElementA;
|
||||
using ElementSFA = typename Gemv::ElementSFA;
|
||||
using LayoutA = typename Gemv::LayoutA;
|
||||
static_assert(cutlass::is_same_v<LayoutA, cutlass::layout::RowMajor>, "only support row major matrix A");
|
||||
static_assert(cutlass::sizeof_bits<ElementSFA>::value == 8, "ElementSFA should be FP8 type");
|
||||
|
||||
using ElementB = typename Gemv::ElementB;
|
||||
using ElementSFB = typename Gemv::ElementSFB;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static_assert(cutlass::is_same_v<ElementA, ElementB>, "only support ElementA ElementB of same type");
|
||||
static_assert(cutlass::sizeof_bits<ElementSFB>::value == 8, "ElementSFB should be FP8 type");
|
||||
|
||||
static_assert(cutlass::is_same_v<LayoutC, cutlass::layout::ColumnMajor>, "only support col major matrix C");
|
||||
|
||||
using ElementD = ElementD_;
|
||||
static_assert(cutlass::is_same_v<LayoutD, cutlass::layout::ColumnMajor>, "only support col major output D");
|
||||
|
||||
using ElementSFD = ElementSFD_;
|
||||
static_assert(cutlass::is_same_v<LayoutSFD, cutlass::layout::ColumnMajor>, "only support col major output SFD");
|
||||
static_assert(cutlass::sizeof_bits<ElementSFD>::value == 8, "only support 8 bit SFD");
|
||||
|
||||
using ElementAccumulator = typename Gemv::ElementAccumulator;
|
||||
using ElementCompute = ElementCompute_;
|
||||
static_assert(cutlass::is_same_v<ElementCompute, float>, "only support fp32 epi compute");
|
||||
|
||||
static constexpr int kVectorSize = kVectorSize_;
|
||||
static_assert(kVectorSize == 16, "only support vs 16");
|
||||
|
||||
// SFD Config
|
||||
static constexpr bool kIsKMajorSFD = cutlass::is_same_v<LayoutSFD, cutlass::layout::RowMajor>;
|
||||
using Sm1xxBlockScaledOutputConfig =
    cutlass::detail::Sm1xxBlockScaledOutputConfig<kVectorSize,
                                                  kIsKMajorSFD ? cute::UMMA::Major::K : cute::UMMA::Major::MN>;
|
||||
using Blk_MN_Output = typename Sm1xxBlockScaledOutputConfig::Blk_MN;
|
||||
using Blk_SF_Output = typename Sm1xxBlockScaledOutputConfig::Blk_SF;
|
||||
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
|
||||
|
||||
// SFA SFB Config
|
||||
using Sm100BlockScaledInputConfig = cutlass::detail::Sm1xxBlockScaledConfig<kVectorSize>;
|
||||
using Blk_MN_Input = typename Sm100BlockScaledInputConfig::Blk_MN;
|
||||
using Blk_SF_Input = typename Sm100BlockScaledInputConfig::Blk_SF;
|
||||
using SfAtom_Input = typename Sm100BlockScaledInputConfig::SfAtom;
|
||||
|
||||
public:
|
||||
TestbedGemvFp4SFDBase(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_SFA_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_SFB_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_SFD_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2023)
|
||||
: init_A(init_A_)
|
||||
, init_B(init_B_)
|
||||
, init_C(init_C_)
|
||||
, init_D(init_D_)
|
||||
, init_SFA(init_SFA_)
|
||||
, init_SFB(init_SFB_)
|
||||
, init_SFD(init_SFD_)
|
||||
, seed(seed_)
|
||||
{
|
||||
}
|
||||
|
||||
bool initialize(cutlass::MatrixCoord problem_size, int32_t batch_count)
|
||||
{
|
||||
const int32_t gemm_m = problem_size.row();
|
||||
const int32_t gemm_k = problem_size.column();
|
||||
const int32_t gemm_n = 1;
|
||||
const int32_t gemm_batch = batch_count;
|
||||
|
||||
// Resize Config SFA/SFB
|
||||
auto k_blks_input = cutlass::ceil_div(gemm_k, cute::size<1>(shape(SfAtom_Input{})));
|
||||
auto m_blks_input = cutlass::ceil_div(gemm_m, Blk_MN_Input{});
|
||||
auto n_blks_input = cutlass::ceil_div(gemm_n, Blk_MN_Input{});
|
||||
|
||||
auto sfa_coord = cutlass::make_Coord(m_blks_input * Blk_MN_Input{} * gemm_batch, k_blks_input * Blk_SF_Input{});
|
||||
auto sfb_coord = cutlass::make_Coord(n_blks_input * Blk_MN_Input{} * gemm_batch, k_blks_input * Blk_SF_Input{});
|
||||
|
||||
auto sfa_resize_layout =
|
||||
cutlass::layout::Affine2Layout_Factory<LayoutA>::layout_factory(sfa_coord, typename LayoutA::Stride{});
|
||||
auto sfb_resize_layout =
|
||||
cutlass::layout::Affine2Layout_Factory<LayoutB>::layout_factory(sfb_coord, typename LayoutB::Stride{});
|
||||
|
||||
// Use the same SFD layout generation as reference for tensor creation
|
||||
using ProblemShapeType = cute::Shape<int, int, int, int>;
|
||||
auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch};
|
||||
|
||||
// Generate the same layout as reference uses
|
||||
auto sfd_layout = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL);
|
||||
|
||||
// Extract size from the generated layout and create coordinate
|
||||
auto sfd_size = cute::size(cute::filter_zeros(sfd_layout));
|
||||
auto sfd_coord = cutlass::make_Coord(sfd_size, 1); // Linear layout for HostTensor
|
||||
|
||||
auto sfd_resize_layout =
|
||||
cutlass::layout::Affine2Layout_Factory<LayoutSFD>::layout_factory(sfd_coord, typename LayoutSFD::Stride{});
|
||||
|
||||
// Resize Host
|
||||
this->reference_D.resize({gemm_batch * gemm_m, 1}); // D col major vector
|
||||
this->reference_SFD.resize(sfd_coord, sfd_resize_layout);
|
||||
|
||||
if (initialize_tensor(this->reference_D.host_view(), this->init_D, this->seed + 7) == false) {
|
||||
printf("initialize_tensor() REF D failed\n");
|
||||
return false;
|
||||
}
|
||||
if (initialize_tensor(this->reference_SFD.host_view(), this->init_SFD, this->seed + 9) == false) {
|
||||
printf("initialize_tensor() REF SFD failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Resize A/B/C/D
|
||||
this->tensor_A.resize({gemm_batch * gemm_m, gemm_k}); // A row major
|
||||
this->tensor_B.resize({gemm_batch * gemm_k, 1}); // B col major vector
|
||||
this->tensor_C.resize({gemm_batch * gemm_m, 1}); // C col major vector
|
||||
this->tensor_D.resize({gemm_batch * gemm_m, 1}); // D col major vector
|
||||
this->tensor_SFA.resize(sfa_coord, sfa_resize_layout);
|
||||
this->tensor_SFB.resize(sfb_coord, sfb_resize_layout);
|
||||
this->tensor_SFD.resize(sfd_coord, sfd_resize_layout);
|
||||
|
||||
// Fill A/B/C
|
||||
if (initialize_tensor(this->tensor_A.host_view(), this->init_A, this->seed + 1) == false) {
|
||||
printf("initialize_tensor() A failed\n");
|
||||
return false;
|
||||
}
|
||||
if (initialize_tensor(this->tensor_B.host_view(), this->init_B, this->seed + 2) == false) {
|
||||
printf("initialize_tensor() B failed\n");
|
||||
return false;
|
||||
}
|
||||
if (initialize_tensor(this->tensor_C.host_view(), this->init_C, this->seed + 3) == false) {
|
||||
printf("initialize_tensor() C failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Fill SFA/SFB
|
||||
if (initialize_tensor(this->tensor_SFA.host_view(), this->init_SFA, this->seed + 4) == false) {
|
||||
printf("initialize_tensor() SFA failed\n");
|
||||
return false;
|
||||
}
|
||||
if (initialize_tensor(this->tensor_SFB.host_view(), this->init_SFB, this->seed + 5) == false) {
|
||||
printf("initialize_tensor() SFB failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Fill D/SFD
|
||||
if (initialize_tensor(this->tensor_D.host_view(), this->init_D, this->seed + 6) == false) {
|
||||
printf("initialize_tensor() D failed\n");
|
||||
return false;
|
||||
}
|
||||
if (initialize_tensor(this->tensor_SFD.host_view(), this->init_SFD, this->seed + 8) == false) {
|
||||
printf("initialize_tensor() SFD failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Copy A/B/C from host to device
|
||||
this->tensor_A.sync_device();
|
||||
this->tensor_B.sync_device();
|
||||
this->tensor_C.sync_device();
|
||||
this->tensor_D.sync_device();
|
||||
this->tensor_SFA.sync_device();
|
||||
this->tensor_SFB.sync_device();
|
||||
this->tensor_SFD.sync_device();
|
||||
|
||||
// SFD initialization is different:
// copy tensor_SFD (already synced to the device) back into reference_SFD on the host.
// This ensures tensor_SFD and reference_SFD hold the same data; otherwise the "bubbles"
// due to SFD layouts can lead to a false-negative sanity check.
|
||||
cutlass::device_memory::copy_to_host(this->reference_SFD.host_data(), this->tensor_SFD.device_data(), sfd_size);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compare_reference()
|
||||
{
|
||||
// device -> host
|
||||
this->tensor_D.sync_host();
|
||||
|
||||
bool passed = true;
|
||||
|
||||
// Check
|
||||
passed = cutlass::reference::host::TensorEquals(this->reference_D.host_view(), this->tensor_D.host_view());
|
||||
if (passed == false) {
|
||||
printf("gemm_m: %d, gemm_k: %d, ", this->tensor_A.host_view().extent(0), this->tensor_A.host_view().extent(1));
|
||||
printf("tensorD mismatch\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
this->tensor_SFD.sync_host();
|
||||
|
||||
passed = cutlass::reference::host::TensorEquals(this->reference_SFD.host_view(), this->tensor_SFD.host_view());
|
||||
if (passed == false) {
|
||||
printf("gemm_m: %d, gemm_k: %d, ", this->tensor_A.host_view().extent(0), this->tensor_A.host_view().extent(1));
|
||||
printf("tensorSFD mismatch\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool run_reference(cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
float epilogue_st)
|
||||
{
|
||||
const int32_t gemm_m = problem_size.row();
|
||||
const int32_t gemm_k = problem_size.column();
|
||||
const int32_t gemm_n = 1;
|
||||
const int32_t gemm_batch = batch_count;
|
||||
|
||||
// Run reference blockscale GETT
|
||||
using ProblemShapeType = cute::Shape<int, int, int, int>;
|
||||
auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch};
|
||||
auto SfD = make_tensor(make_iterator(this->reference_SFD.host_data()),
|
||||
Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL));
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
|
||||
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
|
||||
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutD>;
|
||||
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(gemm_m, gemm_k, gemm_batch));
|
||||
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(gemm_n, gemm_k, gemm_batch));
|
||||
StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(gemm_m, gemm_n, gemm_batch));
|
||||
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(gemm_m, gemm_n, gemm_batch));
|
||||
|
||||
auto A = make_tensor(make_iterator(this->tensor_A.host_data()),
|
||||
cute::make_layout(cute::make_shape(gemm_m, gemm_k, gemm_batch), stride_a));
|
||||
auto B = make_tensor(make_iterator(this->tensor_B.host_data()),
|
||||
cute::make_layout(cute::make_shape(gemm_n, gemm_k, gemm_batch), stride_b));
|
||||
|
||||
auto C = cute::make_tensor(make_iterator(this->tensor_C.host_data()),
|
||||
cute::make_layout(cute::make_shape(gemm_m, gemm_n, gemm_batch), stride_c));
|
||||
auto D = cute::make_tensor(make_iterator(this->reference_D.host_data()),
|
||||
cute::make_layout(cute::make_shape(gemm_m, gemm_n, gemm_batch), stride_d));
|
||||
|
||||
auto layout_sfa = Sm100BlockScaledInputConfig::tile_atom_to_shape_SFA(problem_shape_MNKL);
|
||||
auto layout_sfb = Sm100BlockScaledInputConfig::tile_atom_to_shape_SFB(problem_shape_MNKL);
|
||||
|
||||
auto SfA = make_tensor(this->tensor_SFA.host_data(), layout_sfa);
|
||||
auto SfB = make_tensor(this->tensor_SFB.host_data(), layout_sfb);
|
||||
|
||||
// Internally scale factor of mainloop will be disabled when ElementA/B == ElementSFA/B.
|
||||
typename cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator, // ElementAccumulator
|
||||
decltype(A), // TensorA
|
||||
decltype(SfA), // TensorSfA
|
||||
decltype(B), // TensorB
|
||||
decltype(SfB) // TensorSfB
|
||||
>
|
||||
mainloop_params{A, SfA, B, SfB};
|
||||
|
||||
typename cutlass::reference::host::GettBlockScalingEpilogueParams<ElementCompute, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementCompute, // ElementCompute
|
||||
decltype(C), // TensorC
|
||||
decltype(D), // TensorD
|
||||
decltype(SfD), // TensorSfD
|
||||
cute::Int<kVectorSize>, // OutputVectorSize
|
||||
cutlass::reference::host::SfStrategy::SfDGen
|
||||
>
|
||||
epilogue_params{alpha, beta, C, D, SfD, epilogue_st};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual typename Gemv::Arguments get_arguments(
|
||||
cutlass::MatrixCoord problem_size, int32_t batch_count,
|
||||
float epilogue_st, ElementCompute alpha, ElementCompute beta) = 0;
|
||||
|
||||
bool run_gemv(cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
[[maybe_unused]] float epilogue_st,
|
||||
bool is_profiling,
|
||||
int kIterations)
|
||||
{
|
||||
|
||||
// Batch input is not supported for testing
|
||||
const int32_t gemm_m = problem_size.row();
|
||||
const int32_t gemm_k = problem_size.column();
|
||||
[[maybe_unused]] const int32_t gemm_n = 1;
|
||||
[[maybe_unused]] const int32_t gemm_batch = batch_count;
|
||||
|
||||
Gemv gemv_op;
|
||||
typename Gemv::Arguments arguments = this->get_arguments(
|
||||
problem_size, batch_count, epilogue_st, alpha, beta
|
||||
);
|
||||
|
||||
cutlass::Status status = gemv_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
printf("can_implement() failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t workspace_size = Gemv::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
status = gemv_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
printf("initialize() failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (not is_profiling) {
|
||||
status = gemv_op();
|
||||
}
|
||||
// profiling
|
||||
else {
|
||||
cudaError_t result;
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (cudaEvent_t &evt : events) {
|
||||
result = cudaEventCreate(&evt);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// warmup
|
||||
status = gemv_op();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Device execution failed on warmup." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
result = cudaEventRecord(events[0]);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int iter_i = 0; iter_i < kIterations; ++iter_i) {
|
||||
status = gemv_op();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Device execution failed." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
result = cudaEventRecord(events[1]);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
float elapsed_ms = 0;
|
||||
result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (cudaEvent_t &evt : events) {
|
||||
result = cudaEventDestroy(evt);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t flops = int64_t(gemm_m) * gemm_n * gemm_k * 2;
|
||||
int64_t bytes = cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementA>) * int64_t(gemm_m) * int64_t(gemm_k)) +
|
||||
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementB>) * int64_t(gemm_k) * int64_t(gemm_n)) +
|
||||
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementD>) * int64_t(gemm_m) * int64_t(gemm_n)) +
|
||||
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementSFA>) * int64_t(gemm_m) * int64_t(gemm_k) / int64_t(kVectorSize)) +
|
||||
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementSFB>) * int64_t(gemm_k) * int64_t(gemm_n) / int64_t(kVectorSize)) +
|
||||
cutlass::bits_to_bytes<int64_t>(int64_t(cute::sizeof_bits_v<ElementSFD>) * int64_t(gemm_m) * int64_t(gemm_n) / int64_t(kVectorSize));
|
||||
|
||||
double gflops_per_second = double(flops) * kIterations * gemm_batch / double(elapsed_ms / 1000.0f) / double(1.0e9);
|
||||
double gbytes_per_second = double(bytes) * kIterations * gemm_batch / double(elapsed_ms / 1000.0f) / double(1 << 30);
|
||||
double elapsed_ms_per_iter = double(elapsed_ms) / kIterations;
|
||||
|
||||
std::cout << " Problem: "
|
||||
<< gemm_m << "-by-" << gemm_n << "-by-" << gemm_k
|
||||
<< ", batch size: " << gemm_batch
|
||||
<< std::endl;
|
||||
std::cout << " Runtime: " << elapsed_ms_per_iter << " ms" << std::endl;
|
||||
std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl;
|
||||
std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl;
|
||||
|
||||
}
|
||||
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
printf("gemv exec failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool run_and_verify(cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
float epilogue_st)
|
||||
{
|
||||
|
||||
// Initialize Data
|
||||
if (this->initialize(problem_size, batch_count) == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run GEMV kernel
|
||||
if (this->run_gemv(problem_size, batch_count, alpha, beta, epilogue_st, false /*is_profiling*/, 1) == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run Reference Kernel
|
||||
if (this->run_reference(problem_size, batch_count, alpha, beta, epilogue_st) == false) {
|
||||
printf("run_reference() failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify
|
||||
if (this->compare_reference() == false) {
|
||||
printf("compare_reference() failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool profile(cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count,
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
float epilogue_st,
|
||||
int kIterations = 10)
|
||||
{
|
||||
// Initialize Data
|
||||
if (this->initialize(problem_size, batch_count) == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Profile GEMV kernel
|
||||
if (this->run_gemv(problem_size, batch_count, alpha, beta, epilogue_st, true /*is_profiling*/, kIterations) == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
// Data Storage
|
||||
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
|
||||
cutlass::HostTensor<ElementSFA, LayoutA> tensor_SFA;
|
||||
|
||||
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
|
||||
cutlass::HostTensor<ElementSFB, LayoutB> tensor_SFB;
|
||||
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
|
||||
|
||||
cutlass::HostTensor<ElementD, LayoutD> tensor_D;
|
||||
cutlass::HostTensor<ElementSFD, LayoutD> tensor_SFD;
|
||||
|
||||
cutlass::HostTensor<ElementD, LayoutD> reference_D;
|
||||
cutlass::HostTensor<ElementSFD, LayoutD> reference_SFD;
|
||||
|
||||
// Data Init Setting
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_D;
|
||||
cutlass::Distribution::Kind init_SFA;
|
||||
cutlass::Distribution::Kind init_SFB;
|
||||
cutlass::Distribution::Kind init_SFD;
|
||||
uint64_t seed;
|
||||
};
|
||||
|
||||
template<typename Gemv_>
|
||||
struct TestbedGemvFp4SFD : public TestbedGemvFp4SFDBase<
|
||||
Gemv_,
|
||||
typename Gemv_::ElementC,
|
||||
typename Gemv_::EpilogueOutputOp::LayoutOutput,
|
||||
typename Gemv_::EpilogueOutputOp::ElementD,
|
||||
typename Gemv_::EpilogueOutputOp::LayoutOutput,
|
||||
typename Gemv_::EpilogueOutputOp::ElementSFD,
|
||||
typename Gemv_::EpilogueOutputOp::LayoutSFD,
|
||||
typename Gemv_::EpilogueOutputOp::ElementCompute,
|
||||
Gemv_::EpilogueOutputOp::kVectorSize
|
||||
> {
|
||||
using Base = TestbedGemvFp4SFDBase<
|
||||
Gemv_,
|
||||
typename Gemv_::ElementC,
|
||||
typename Gemv_::EpilogueOutputOp::LayoutOutput,
|
||||
typename Gemv_::EpilogueOutputOp::ElementD,
|
||||
typename Gemv_::EpilogueOutputOp::LayoutOutput,
|
||||
typename Gemv_::EpilogueOutputOp::ElementSFD,
|
||||
typename Gemv_::EpilogueOutputOp::LayoutSFD,
|
||||
typename Gemv_::EpilogueOutputOp::ElementCompute,
|
||||
Gemv_::EpilogueOutputOp::kVectorSize
|
||||
>;
|
||||
|
||||
using Base::Base;
|
||||
using Gemv = Gemv_;
|
||||
using ElementCompute = typename Base::ElementCompute;
|
||||
using SfAtom_Input = typename Base::SfAtom_Input;
|
||||
using Blk_MN_Input = typename Base::Blk_MN_Input;
|
||||
using Blk_SF_Input = typename Base::Blk_SF_Input;
|
||||
|
||||
static constexpr int kVectorSize = Base::kVectorSize;
|
||||
|
||||
typename Gemv::Arguments get_arguments(
|
||||
cutlass::MatrixCoord problem_size,
|
||||
int32_t batch_count, float epilogue_st,
|
||||
ElementCompute alpha, ElementCompute beta) override {
|
||||
|
||||
const int32_t gemm_m = problem_size.row();
|
||||
const int32_t gemm_k = problem_size.column();
|
||||
[[maybe_unused]] const int32_t gemm_n = 1;
|
||||
[[maybe_unused]] const int32_t gemm_batch = batch_count;
|
||||
|
||||
auto k_blks_input = cutlass::ceil_div(gemm_k, cute::size<1>(shape(SfAtom_Input{})));
|
||||
auto m_blks_input = cutlass::ceil_div(gemm_m, Blk_MN_Input{});
|
||||
auto n_blks_input = cutlass::ceil_div(gemm_n, Blk_MN_Input{});
|
||||
|
||||
int batch_stride_SFA = m_blks_input * Blk_MN_Input{} * k_blks_input * Blk_SF_Input{};
|
||||
int batch_stride_SFB = n_blks_input * Blk_MN_Input{} * k_blks_input * Blk_SF_Input{};
|
||||
|
||||
// Use the same SFD layout generation as reference to get correct batch stride
|
||||
using ProblemShapeType = cute::Shape<int, int, int, int>;
|
||||
auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch};
|
||||
|
||||
// Generate the same layout as reference uses
|
||||
using Sm1xxBlockScaledOutputConfig = typename Base::Sm1xxBlockScaledOutputConfig;
|
||||
auto sfd_layout = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL);
|
||||
|
||||
// Calculate the batch stride from the generated layout:
// stride<2> of the SFD layout is the stride of the batch (L) dimension.
auto batch_stride_tuple = cute::stride<2>(sfd_layout);                      // e.g. (_0, 8192) for the default problem size
int batch_stride_SFD = static_cast<int>(cute::get<1>(batch_stride_tuple));  // keep the dynamic component
|
||||
|
||||
// Initialize GEMV kernel
|
||||
typename Gemv::Arguments arguments{
|
||||
problem_size, // problem_size
|
||||
batch_count, // batch_count
|
||||
typename Gemv::EpilogueOutputOp::Params{
|
||||
this->tensor_D.device_ref(), // tensor_d
|
||||
this->tensor_SFD.device_data(), // scale_factor_d_ptr
|
||||
alpha, // alpha
|
||||
beta, // beta
|
||||
epilogue_st, // st
|
||||
batch_stride_SFD, // batch_stride_sfd
|
||||
gemm_m // stride_d
|
||||
},
|
||||
this->tensor_A.device_ref(), // ref_A
|
||||
this->tensor_B.device_data(), // ptr_B
|
||||
this->tensor_C.device_data(), // ptr_C
|
||||
this->tensor_D.device_data(), // ptr_D
|
||||
this->tensor_SFA.device_data(), // ptr_SFA
|
||||
this->tensor_SFB.device_data(), // ptr_SFB
|
||||
gemm_k, // stride_A
|
||||
gemm_m * gemm_k, // batch_stride_A
|
||||
gemm_k, // batch_stride_B
|
||||
gemm_m, // batch_stride_C
|
||||
gemm_m, // batch_stride_D
|
||||
batch_stride_SFA, // batch_stride_SFA
|
||||
batch_stride_SFB, // batch_stride_SFB
|
||||
batch_stride_SFD // batch_stride_SFD
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
};
|
||||
|
||||
struct Options {
|
||||
bool help = false;
|
||||
|
||||
int m = 4096;
|
||||
int k = 2048;
|
||||
int n = 1;
|
||||
int batch = 1;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float epilogue_st = -1.0f; // sentinel for random
|
||||
|
||||
bool profiling = true;
|
||||
int iterations = 10;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("batch", batch);
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
cmd.get_cmd_line_argument("epilogue_st", epilogue_st);
|
||||
cmd.get_cmd_line_argument("profiling", profiling);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "91_fp4_gemv\n\n"
|
||||
<< " FP4 GEMV with block-scaled inputs and outputs.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --batch=<int> Sets the batch count of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --epilogue_st=<f32> Epilogue ST value\n\n"
|
||||
<< " --profiling=<bool> Whether to run profiling\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "91_fp4_gemv" << " --m=4096 --k=2048 --batch=1 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
bool
|
||||
run_fp4_gemv_device(Options const& options)
|
||||
{
|
||||
CUTLASS_ASSERT(options.n == 1);
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementSFA = cutlass::float_e4m3_t;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementSFB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementC = cutlass::float_e2m1_t;
|
||||
|
||||
using ElementD = cutlass::float_e2m1_t;
|
||||
using LayoutD = cutlass::layout::ColumnMajor;
|
||||
|
||||
using ElementSFD = cutlass::float_e4m3_t;
|
||||
// Indicate SF is computed along col dim. Does NOT indicate actual layout of SFD
|
||||
using LayoutSFD = cutlass::layout::ColumnMajor;
|
||||
|
||||
using ElementAccumulatorMainloop = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha{options.alpha};
|
||||
ElementCompute beta{options.beta};
|
||||
// epilogue_st must be positive; if unspecified (< 0), a random value in roughly (0, 5] is used.
|
||||
const float epilogue_st = options.epilogue_st < 0.f ?
|
||||
static_cast<float>(rand()) / (static_cast<float>(RAND_MAX / 5)) :
|
||||
options.epilogue_st;
|
||||
|
||||
static constexpr int kVectorSize = 16;
|
||||
static constexpr int kElementsPerAccess = 128 / cutlass::sizeof_bits<ElementA>::value;
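// With 4-bit ElementA (float_e2m1_t), kElementsPerAccess = 128 / 4 = 32 elements per 128-bit
// access, and kVectorSize = 16 means one FP8 scale factor covers 16 FP4 values.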
|
||||
|
||||
using ThreadShape = cutlass::gemm::GemmShape<16, 8>;
|
||||
static_assert(kVectorSize == ThreadShape::kM, "vector size and thread in row should be equal");
|
||||
|
||||
// Construct Epilogue
|
||||
using EpilogueOp = typename cutlass::epilogue::threadblock::GemvEpilogueWithScalingFactor<kVectorSize,
|
||||
ThreadShape,
|
||||
ElementCompute,
|
||||
ElementAccumulator,
|
||||
ElementC,
|
||||
ElementD,
|
||||
ElementSFD,
|
||||
LayoutD,
|
||||
LayoutSFD>;
|
||||
|
||||
// Construct Mainloop
|
||||
using Gemv = cutlass::gemm::device::GemvBlockScaled<
|
||||
cutlass::gemm::kernel::
|
||||
GemvBlockScaled<ElementA, LayoutA, ElementB, ElementD, ElementAccumulatorMainloop, EpilogueOp, kElementsPerAccess>>;
|
||||
|
||||
TestbedGemvFp4SFD<Gemv> testbed;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if (options.profiling) {
|
||||
pass = testbed.profile(cutlass::MatrixCoord{options.m, options.k}, options.batch, alpha, beta, epilogue_st, options.iterations);
|
||||
}
|
||||
else {
|
||||
pass = testbed.run_and_verify(cutlass::MatrixCoord{options.m, options.k}, options.batch, alpha, beta, epilogue_st);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int
|
||||
main(int argc, char const** argv)
|
||||
{
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
Options options;
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Run verification
|
||||
Options verification_options = options;
|
||||
verification_options.profiling = false;
|
||||
|
||||
bool passed = run_fp4_gemv_device(verification_options);
|
||||
if (passed == false) {
|
||||
printf("test fail\n");
|
||||
return 1;
|
||||
} else {
|
||||
printf("test pass\n");
|
||||
}
|
||||
|
||||
|
||||
if (options.profiling) {
|
||||
// Start profiling
|
||||
printf("\nProfiling...\n");
|
||||
passed = run_fp4_gemv_device(options);
|
||||
if (passed == false) {
|
||||
printf("profiling fail\n");
|
||||
return 1;
|
||||
} else {
|
||||
printf("profiling completed\n");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return 0;
|
||||
#else
|
||||
std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM100_SUPPORTED is defined.\n";
|
||||
return 0;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
}
|
||||
36
examples/91_fp4_gemv/CMakeLists.txt
Normal file
36
examples/91_fp4_gemv/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (NOT MSVC)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
91_fp4_gemv
|
||||
91_fp4_gemv.cu
|
||||
)
|
||||
|
||||
endif()
|
||||
@ -0,0 +1,701 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Example of Blackwell MoE-style grouped NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B.
|
||||
|
||||
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices.
|
||||
    In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
    can vary widely, which would require frequent updates of TMA descriptors in a pure TMA-based
    implementation. This example uses CPASYNC to load the activation (B) matrix to avoid the
    overhead of updating TMA descriptors.
|
||||
|
||||
Usage:
|
||||
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped
|
||||
--m=28672 --n=4 --k=4096 --l=8 --benchmark=benchmark.txt
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
#include <fstream>

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"

#include "helper.h"
|
||||
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
bool verification;
|
||||
|
||||
int m, n, k, l;
|
||||
|
||||
int iterations;
|
||||
|
||||
std::string benchmark_path;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
verification(true),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, 2048);
|
||||
cmd.get_cmd_line_argument("n", n, 2048);
|
||||
cmd.get_cmd_line_argument("k", k, 2048);
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("iterations", iterations, 10);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
|
||||
|
||||
if (cmd.check_cmd_line_flag("no_verif")) {
|
||||
verification = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "92_blackwell_moe_gemm_fp4_grouped\n\n"
|
||||
<< " Blackwell MoE-style grouped NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --iterations=<int> Set the number of profiling iterations to perform\n"
|
||||
<< " --benchmark=<file> Executes a benchmark problem size\n"
|
||||
<< " --no_verif Do not run verification kernels\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class Element, class Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t> || cute::is_same_v<Element, cutlass::float_ue4m3_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
auto make_iterator(T* ptr) {
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleRunner {
|
||||
// Type of kernel schedule to generate
|
||||
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100;
|
||||
// Type of epilogue schedule to generate
|
||||
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
static constexpr bool FuseQuantization = false;
|
||||
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor;
|
||||
using LayoutCTag = cutlass::layout::ColumnMajor;
|
||||
using LayoutDTag = cutlass::layout::ColumnMajor;
|
||||
using LayoutSFDTag = LayoutDTag;  // Layout type for SFD should be the same as the D matrix operand
|
||||
|
||||
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
|
||||
using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands
|
||||
|
||||
using ElementA = cutlass::nv_float4_t<ElementInput>;
|
||||
using ElementB = cutlass::nv_float4_t<ElementInput>;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cute::conditional_t<FuseQuantization, ElementInput, ElementC>;
|
||||
using ElementSFD = ElementSF;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementScalar = float;
|
||||
|
||||
|
||||
|
||||
using ClusterShapeMNK = Shape<_1,_1,_1>;
|
||||
using MmaTileMNK = Shape<_128,_64,_256>;  // use a tile size of N=64 to match real use cases (N is typically very small in the decoding stage)
|
||||
|
||||
static constexpr int AlignmentA = 32;
|
||||
static constexpr int AlignmentB = 32;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
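// AlignmentA/B of 32 four-bit elements corresponds to one 128-bit access; for half_t C
// (and D when quantization is not fused), 128 / 16 = 8 elements per access.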
|
||||
|
||||
static constexpr int OutputSFVectorSize = 16;
|
||||
|
||||
// D = alpha * acc + beta * C
|
||||
// With BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, LayoutSFDTag,
|
||||
ElementC>;
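// Conceptually (a sketch, not the library's exact math): when FuseQuantization is enabled,
// each OutputSFVectorSize-wide chunk of D = alpha * acc + beta * C gets a block scale factor
// written through block_scale_factor_ptr, and the chunk is quantized to ElementD;
// norm_constant_ptr supplies the normalization constant used in scale-factor generation.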
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
cute::conditional_t<
|
||||
FuseQuantization,
|
||||
FusionOperation,
|
||||
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using ProblemShapeGroup = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ProblemShapeMax = Shape<int,int,int,int>; // max <M,N,K,L>
|
||||
using ProblemShape = cutlass::gemm::MoEProblemShape<ProblemShapeGroup, ProblemShapeMax>;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
|
||||
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
LayoutSFD layout_SFD;
|
||||
uint64_t seed = 0;
|
||||
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShapeGroup::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(ProblemShape const& problem_size, float alpha, float beta) {
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
// TODO: think about how to simplify the Gemm3x interface.
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
|
||||
|
||||
if constexpr (FuseQuantization) {
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementCompute, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D), // TensorD
|
||||
decltype(tensor_SFD), // TensorSfD
|
||||
cute::Int<OutputSFVectorSize>,
|
||||
cutlass::reference::host::SfStrategy::SfDGen
|
||||
> epilogue_params {alpha, beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
}
|
||||
else {
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementCompute, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params {alpha, beta, tensor_C, tensor_D };
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
bool passed = true;
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
|
||||
auto [maxM, maxN, maxK, L] = problem_size.max_problem_shape;
|
||||
for (int i = 0; i < problem_size.problem_shape.num_groups; i++) {
|
||||
auto problem = problem_size.problem_shape.get_host_problem_shape(i);
|
||||
auto [M, N, K] = problem;
|
||||
|
||||
// Assume every group's M equals maxM, so group i's output begins at offset i * maxN * maxM
|
||||
auto refD_view = block_reference_D.host_view().subview(cutlass::make_Coord(M * N), cutlass::make_Coord(i * maxN * maxM));
|
||||
auto D_view = block_D.host_view().subview(cutlass::make_Coord(M * N), cutlass::make_Coord(i * maxN * maxM));
|
||||
|
||||
passed &= cutlass::reference::host::TensorEquals(refD_view, D_view);
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(ProblemShape const& problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size.max_problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
// For SFD tensor layout
|
||||
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
// printf("\nStrideC = ");
|
||||
// print(StrideC{});
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, L});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, L});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, L});
|
||||
|
||||
// printf("\nstride_C = ");
|
||||
// print(stride_C);
|
||||
|
||||
layout_A = make_layout(make_shape(M, K, L), stride_A);
|
||||
layout_B = make_layout(make_shape(N, K, L), stride_B);
|
||||
layout_C = make_layout(make_shape(M, N, L), stride_C);
|
||||
layout_D = make_layout(make_shape(M, N, L), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, L));
|
||||
|
||||
// printf("\nlayout_A = ");
|
||||
// print(layout_A);
|
||||
// printf("\nlayout_B = ");
|
||||
// print(layout_B);
|
||||
// printf("\nlayout_C = ");
|
||||
// print(layout_C);
|
||||
|
||||
// printf("\nsize(layout_A)=%lld", (long long)size(layout_A));
|
||||
// printf("\n");
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
block_Normconst.reset(cutlass::make_Coord(1));
|
||||
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
block_Normconst.at(cutlass::make_Coord(0)) = 2;
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_D.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
block_SFD.sync_device();
|
||||
block_Normconst.sync_device();
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
std::vector<ProblemShapeGroup::UnderlyingProblemShape> benchmark_problems(std::string const& benchmark_path) {
|
||||
std::vector<ProblemShapeGroup::UnderlyingProblemShape> problem_sizes_host;
|
||||
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
extent.at(i) = std::atoi(tokens.at(i).c_str());
|
||||
}
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
|
||||
return problem_sizes_host;
|
||||
}
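// Expected benchmark file format (a sketch inferred from the parser above; the sizes shown
// are purely illustrative):
//
//   0 1024x4x4096
//   1 2048x4x4096
//   2 512x4x4096
//
// Each line holds "<group index> <M>x<N>x<K>" for one group/expert.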
|
||||
|
||||
bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto problem_sizes_host = benchmark_problems(options.benchmark_path);
|
||||
if (problem_sizes_host.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
problem_sizes.reset(problem_sizes_host.size());
|
||||
problem_sizes.copy_from_host(problem_sizes_host.data());
|
||||
|
||||
ProblemShape problem_size;
|
||||
problem_size.max_problem_shape = ProblemShapeMax{options.m, options.n, options.k, options.l};
|
||||
problem_size.problem_shape.num_groups = problem_sizes_host.size();
|
||||
problem_size.problem_shape.problem_shapes = problem_sizes.get();
|
||||
problem_size.problem_shape.host_problem_shapes = problem_sizes_host.data();
|
||||
|
||||
initialize(problem_size);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
problem_size,
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
},
|
||||
hw_info
|
||||
};
|
||||
|
||||
auto f = [&](auto blockscale) {
|
||||
auto impl = [this](auto& arguments) {
|
||||
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
};
|
||||
if constexpr (decltype(blockscale)::value) {
|
||||
impl(arguments);
|
||||
}
|
||||
};
|
||||
f(std::bool_constant<IsBlockScaleSupported>());
|
||||
|
||||
// arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
arguments.epilogue.thread.alpha = 1.0f;
|
||||
arguments.epilogue.thread.beta = 0.0f;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run the GEMM
|
||||
status = gemm_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (options.verification) {
|
||||
// Verify that the result is correct
|
||||
bool passed = verify(problem_size, 1.0f, 0.0f);
|
||||
|
||||
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!passed) {
|
||||
exit(-1);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_op.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and FLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);
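      // Note: FLOPs are computed from the command-line (maximum) extents, not from the
      // per-group sizes read from the benchmark file.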
|
||||
|
||||
std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " TFLOPS : " << flops / 1e12 << std::endl;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
|
||||
ExampleRunner runner_mixed_tma_cpasync;
|
||||
runner_mixed_tma_cpasync.run(options, hw_info);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -0,0 +1,654 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Example of Blackwell MoE-style NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B
|
||||
|
||||
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices.
|
||||
In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
can vary widely, which requires frequent updates of the TMA descriptors in a TMA-based implementation.
This example uses CPASYNC to load the activation (B) matrix to avoid the overhead of updating TMA descriptors.

This example assumes all experts have the same number of tokens, in which case the GEMM becomes a regular (batched) GEMM.
For the realistic use case where each expert may have a different number of tokens (grouped GEMM), see 92_blackwell_moe_gemm_fp4_grouped.
|
||||
|
||||
Usage:
|
||||
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular
|
||||
--m=28672 --n=4 --k=4096 --l=8
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
bool verification;
|
||||
|
||||
int m, n, k, l;
|
||||
|
||||
int iterations;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
verification(true),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, 2048);
|
||||
cmd.get_cmd_line_argument("n", n, 2048);
|
||||
cmd.get_cmd_line_argument("k", k, 2048);
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("iterations", iterations, 10);
|
||||
|
||||
if (cmd.check_cmd_line_flag("no_verif")) {
|
||||
verification = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "92_blackwell_moe_gemm_fp4_regular\n\n"
|
||||
<< " Blackwell NVFP4 GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --iterations=<int> Set the number of profiling iterations to perform\n"
|
||||
<< " --no_verif Do not run verification kernels\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class Element, class Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t> || cute::is_same_v<Element, cutlass::float_ue4m3_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else {
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
auto make_iterator(T* ptr) {
|
||||
return cute::recast_ptr<T>(ptr);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MSVC complains about these if they are moved into ExampleRunner
|
||||
static constexpr int OutputSFVectorSize = 16;
|
||||
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
|
||||
// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
|
||||
// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the
|
||||
// number of pipeline stages.
|
||||
template <
|
||||
// Type of kernel schedule to generate
|
||||
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
|
||||
// Type of epilogue schedule to generate
|
||||
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
bool FuseQuantization = false
|
||||
>
|
||||
struct ExampleRunner {
|
||||
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor;
|
||||
using LayoutCTag = cutlass::layout::ColumnMajor;
|
||||
using LayoutDTag = cutlass::layout::ColumnMajor;
|
||||
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
|
||||
|
||||
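  // NVFP4 operands: e2m1 (4-bit) data values with a ue4m3 scale factor shared by each
  // block of 16 elements.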
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
|
||||
using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands
|
||||
|
||||
using ElementA = cutlass::nv_float4_t<ElementInput>;
|
||||
using ElementB = cutlass::nv_float4_t<ElementInput>;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cute::conditional_t<FuseQuantization, ElementInput, ElementC>;
|
||||
using ElementSFD = ElementSF;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementScalar = float;
|
||||
|
||||
|
||||
|
||||
using ClusterShapeMNK = Shape<_1,_1,_1>;
|
||||
using MmaTileMNK = Shape<_128,_64,_256>; // use a tile size of N=64 to match real use cases (N is typically very small in the decoding stage)
|
||||
|
||||
static constexpr int AlignmentA = 32;
|
||||
static constexpr int AlignmentB = 32;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
|
||||
// D = alpha * acc + beta * C
|
||||
// With BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, LayoutSFDTag,
|
||||
ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
cute::conditional_t<
|
||||
FuseQuantization,
|
||||
FusionOperation,
|
||||
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
LayoutSFD layout_SFD;
|
||||
uint64_t seed = 0;
|
||||
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(ProblemShapeType const& problem_size, float alpha, float beta) {
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
// think about how to simplify the gemm3x interface.
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
|
||||
|
||||
if constexpr (FuseQuantization) {
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementCompute, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D), // TensorD
|
||||
decltype(tensor_SFD), // TensorSfD
|
||||
cute::Int<OutputSFVectorSize>,
|
||||
cutlass::reference::host::SfStrategy::SfDGen
|
||||
> epilogue_params {alpha, beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
}
|
||||
else {
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementCompute, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params {alpha, beta, tensor_C, tensor_D };
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
bool passed = true, passed_sfd = true;
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
|
||||
if constexpr (FuseQuantization) {
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
block_SFD.sync_host();
|
||||
passed_sfd &= cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view());
|
||||
passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0);
|
||||
passed_sfd &= (cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0);
|
||||
}
|
||||
|
||||
// printf("passed=%d\n", (int)passed);
|
||||
// printf("passed_sfd=%d\n", (int)passed_sfd);
|
||||
return passed && passed_sfd;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(ProblemShapeType const& problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
// For SFD tensor layout
|
||||
using Sm1xxBlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
// printf("\nStrideC = ");
|
||||
// print(StrideC{});
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, L});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, L});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, L});
|
||||
|
||||
// printf("\nstride_C = ");
|
||||
// print(stride_C);
|
||||
|
||||
layout_A = make_layout(make_shape(M, K, L), stride_A);
|
||||
layout_B = make_layout(make_shape(N, K, L), stride_B);
|
||||
layout_C = make_layout(make_shape(M, N, L), stride_C);
|
||||
layout_D = make_layout(make_shape(M, N, L), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, L));
|
||||
|
||||
// printf("\nlayout_A = ");
|
||||
// print(layout_A);
|
||||
// printf("\nlayout_B = ");
|
||||
// print(layout_B);
|
||||
// printf("\nlayout_C = ");
|
||||
// print(layout_C);
|
||||
|
||||
// printf("\nsize(layout_A)=%lld", (long long)size(layout_A));
|
||||
// printf("\n");
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
block_Normconst.reset(cutlass::make_Coord(1));
|
||||
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
block_Normconst.at(cutlass::make_Coord(0)) = 2;
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_D.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
block_SFD.sync_device();
|
||||
block_Normconst.sync_device();
|
||||
}
|
||||
|
||||
bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
|
||||
|
||||
initialize(problem_size);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
},
|
||||
hw_info
|
||||
};
|
||||
|
||||
if constexpr (IsBlockScaleSupported) {
|
||||
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
}
|
||||
|
||||
// arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
arguments.epilogue.thread.alpha = 1.0f;
|
||||
arguments.epilogue.thread.beta = 0.0f;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run the GEMM
|
||||
status = gemm_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (options.verification) {
|
||||
// Verify that the result is correct
|
||||
bool passed = verify(problem_size, 1.0f, 0.0f);
|
||||
|
||||
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!passed) {
|
||||
exit(-1);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_op.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and FLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " TFLOPS : " << flops / 1e12 << std::endl;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
std::cout << "Running kernel with TMA load:" << std::endl;
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100> runner_tma;
|
||||
runner_tma.run(options, hw_info);
|
||||
|
||||
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
|
||||
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100> runner_mixed_tma_cpasync;
|
||||
runner_mixed_tma_cpasync.run(options, hw_info);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
541
examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu
Normal file
@ -0,0 +1,541 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Example of Blackwell MoE-style grouped GEMM implementation using TMA to load A and CPASYNC to load B.
|
||||
|
||||
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices.
|
||||
In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
can vary widely, which requires frequent updates of the TMA descriptors in a TMA-based implementation.
This example uses CPASYNC to load the activation (B) matrix to avoid the overhead of updating TMA descriptors.
|
||||
|
||||
Usage:
|
||||
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped
|
||||
--m=28672 --n=4 --k=4096 --l=8 --benchmark=benchmark.txt
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
bool verification;
|
||||
|
||||
int m, n, k, l;
|
||||
|
||||
int iterations;
|
||||
|
||||
std::string benchmark_path;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
verification(true),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, 2048);
|
||||
cmd.get_cmd_line_argument("n", n, 2048);
|
||||
cmd.get_cmd_line_argument("k", k, 2048);
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("iterations", iterations, 10);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
|
||||
|
||||
if (cmd.check_cmd_line_flag("no_verif")) {
|
||||
verification = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "92_blackwell_moe_gemm_grouped\n\n"
|
||||
<< " Blackwell MoE-style grouped GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --iterations=<int> Set the number of profiling iterations to perform\n"
|
||||
<< " --benchmark=<file> Executes a benchmark problem size\n"
|
||||
<< " --no_verif Do not run verification kernels\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(0);
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(-2);
|
||||
}
|
||||
else {
|
||||
scope_max = static_cast<Element>(8);
|
||||
scope_min = static_cast<Element>(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleRunner {
|
||||
|
||||
// Type of kernel schedule to generate
|
||||
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100;
|
||||
// Type of epilogue schedule to generate
|
||||
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using LayoutD = cutlass::layout::ColumnMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementScalar = float;
|
||||
|
||||
using ClusterShapeMNK = Shape<_1,_1,_1>;
|
||||
using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use a tile size of N=16 to match real use cases (N is typically very small in the decoding stage)
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using ProblemShapeGroup = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ProblemShapeMax = Shape<int,int,int,int>; // max <M,N,K,L>
|
||||
using ProblemShape = cutlass::gemm::MoEProblemShape<ProblemShapeGroup, ProblemShapeMax>;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
//, cutlass::gemm::MoEScheduler
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
|
||||
using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
|
||||
using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
|
||||
using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed = 0;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShapeGroup::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(ProblemShape const& problem_size, float alpha, float beta) {
|
||||
auto [maxM, maxN, maxK, L] = problem_size.max_problem_shape;
|
||||
for (int i = 0; i < problem_size.problem_shape.num_groups; i++) {
|
||||
auto problem = problem_size.problem_shape.get_host_problem_shape(i);
|
||||
auto [M, N, K] = problem;
|
||||
|
||||
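      // Operands for all groups are laid out at the maximum problem shape, so group i starts
      // at offset i * maxM * maxK in A, i * maxN * maxK in B, and i * maxM * maxN in C/D.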
cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, Gemm::LayoutA(maxK));
|
||||
cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, Gemm::LayoutB(maxK));
|
||||
cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutC(maxM));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutD(maxM));
|
||||
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementScalar,
|
||||
ElementAccumulator>;
|
||||
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
gemm_reference(
|
||||
{M, N, K},
|
||||
ElementScalar(alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementScalar(beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
// assume all M == maxM
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + size_t(1) * i * maxN * maxM, block_D.get() + size_t(1) * i * maxN * maxM, M * N);
|
||||
if (!passed) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(ProblemShape const& problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size.max_problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
block_A.reset(size_t(1) * M * K * L);
|
||||
block_B.reset(size_t(1) * K * N * L);
|
||||
block_C.reset(size_t(1) * M * N * L);
|
||||
block_D.reset(size_t(1) * M * N * L);
|
||||
block_ref_D.reset(size_t(1) * M * N * L);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
std::vector<ProblemShapeGroup::UnderlyingProblemShape> benchmark_problems(std::string const& benchmark_path) {
|
||||
std::vector<ProblemShapeGroup::UnderlyingProblemShape> problem_sizes_host;
|
||||
|
||||
std::ifstream file(benchmark_path);
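    // Benchmark file format: one "<group index> <M>x<N>x<K>" entry per line,
    // e.g. (hypothetical) "0 1024x8x4096".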
|
||||
if (!file.good()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
extent.at(i) = std::atoi(tokens.at(i).c_str());
|
||||
}
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
|
||||
return problem_sizes_host;
|
||||
}
|
||||
|
||||
bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto problem_sizes_host = benchmark_problems(options.benchmark_path);
|
||||
if (problem_sizes_host.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
problem_sizes.reset(problem_sizes_host.size());
|
||||
problem_sizes.copy_from_host(problem_sizes_host.data());
|
||||
|
||||
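    // The MoE problem shape pairs the per-group <M,N,K> extents (device pointer plus host copy)
    // with the maximum <M,N,K,L> shape that determines the memory layout of the operands.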
ProblemShape problem_size;
|
||||
problem_size.max_problem_shape = ProblemShapeMax{options.m, options.n, options.k, options.l};
|
||||
problem_size.problem_shape.num_groups = problem_sizes_host.size();
|
||||
problem_size.problem_shape.problem_shapes = problem_sizes.get();
|
||||
problem_size.problem_shape.host_problem_shapes = problem_sizes_host.data();
|
||||
|
||||
initialize(problem_size);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
problem_size,
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{}, // epilogue.thread
|
||||
block_C.get(), stride_C, block_D.get(), stride_D},
|
||||
hw_info
|
||||
};
|
||||
|
||||
// arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
arguments.epilogue.thread.alpha = 1.0f;
|
||||
arguments.epilogue.thread.beta = 0.0f;
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run the GEMM
|
||||
status = gemm_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (options.verification) {
|
||||
// Verify that the result is correct
|
||||
bool passed = verify(problem_size, 1.0f, 0.0f);
|
||||
|
||||
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!passed) {
|
||||
exit(-1);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_op.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and FLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);
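      // Note: FLOPs are computed from the command-line (maximum) extents, not from the
      // per-group sizes read from the benchmark file.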
|
||||
|
||||
std::cout << " Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " TFLOPS : " << flops / 1e12 << std::endl;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
|
||||
ExampleRunner runner_mixed_tma_cpasync;
|
||||
runner_mixed_tma_cpasync.run(options, hw_info);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
484
examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu
Normal file
@ -0,0 +1,484 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Example of Blackwell MoE-style GEMM implementation using TMA to load A and CPASYNC to load B.
|
||||
|
||||
This example demonstrates an implementation of GEMM using mixed TMA+CPASYNC to load input matrices.
|
||||
In the decoding stage of Mixture of Experts (MoE) models, the number of tokens per expert
can vary widely, which requires frequent updates of the TMA descriptors in a TMA-based implementation.
This example uses CPASYNC to load the activation (B) matrix to avoid the overhead of updating TMA descriptors.
|
||||
|
||||
Usage:
|
||||
$ ./examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular
|
||||
--m=28672 --n=4 --k=4096 --l=8
|
||||
|
||||
*/
|
||||

#include <iostream>

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"

#include "helper.h"


using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Command line options parsing
struct Options {

  bool help;
  bool error;
  bool verification;

  int m, n, k, l;

  int iterations;

  Options():
    help(false),
    error(false),
    verification(true),
    m(2048), n(2048), k(2048), l(1),
    iterations(10)
  { }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
      return;
    }

    cmd.get_cmd_line_argument("m", m, 2048);
    cmd.get_cmd_line_argument("n", n, 2048);
    cmd.get_cmd_line_argument("k", k, 2048);
    cmd.get_cmd_line_argument("l", l, 1);
    cmd.get_cmd_line_argument("iterations", iterations, 10);

    if (cmd.check_cmd_line_flag("no_verif")) {
      verification = false;
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "92_blackwell_moe_gemm_regular\n\n"
        << "  Blackwell GEMM implementation using TMA to load A and CPASYNC to load B\n\n"
        << "Options:\n\n"
        << "  --help                      If specified, displays this usage statement\n\n"
        << "  --m=<int>                   Sets the M extent of the GEMM\n"
        << "  --n=<int>                   Sets the N extent of the GEMM\n"
        << "  --k=<int>                   Sets the K extent of the GEMM\n"
        << "  --l=<int>                   Sets the L extent (batch count) of the GEMM\n"
        << "  --iterations=<int>          Sets the number of profiling iterations to perform\n"
        << "  --no_verif                  Do not run verification kernels\n";

    return out;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
  cutlass::DeviceAllocation<Element>& block,
  uint64_t seed=2023) {

  Element scope_max, scope_min;
  int bits_input = cutlass::sizeof_bits<Element>::value;

  if (bits_input == 1) {
    scope_max = static_cast<Element>(2);
    scope_min = static_cast<Element>(0);
  }
  else if (bits_input <= 8) {
    scope_max = static_cast<Element>(2);
    scope_min = static_cast<Element>(-2);
  }
  else {
    scope_max = static_cast<Element>(8);
    scope_min = static_cast<Element>(-8);
  }

  cutlass::reference::device::BlockFillRandomUniform(
    block.get(), block.size(), seed, scope_max, scope_min, 0);

  return true;
}
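
// Note: the narrow, integer-valued ranges above keep the random operands small enough that the
// reference accumulation stays exact; this is presumably what allows verify() below to use a
// bit-exact BlockCompareEqual check instead of a tolerance-based comparison.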

///////////////////////////////////////////////////////////////////////////////////////////////////


// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the
// number of pipeline stages.
template <
  // Type of kernel schedule to generate
  class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
  // Type of epilogue schedule to generate
  class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto
>
struct ExampleRunner {

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;
  using LayoutD = cutlass::layout::ColumnMajor;

  using ElementA = cutlass::float_e4m3_t;
  using ElementB = cutlass::float_e4m3_t;
  using ElementC = cutlass::half_t;
  using ElementD = cutlass::half_t;
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementScalar = float;

  using ClusterShapeMNK = Shape<_1,_1,_1>;
  using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage)
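  // With ElementA = float_e4m3_t (1 byte), 128 / sizeof(ElementA) evaluates to 128, so the MMA
  // tile is 128x16x128 (M x N x K).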

  // 16B alignment lets us use TMA
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    MmaTileMNK, ClusterShapeMNK,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementCompute,
    ElementC, LayoutC, AlignmentC,
    ElementD, LayoutD, AlignmentD,
    EpilogueScheduleType,
    cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
  >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    MmaTileMNK, ClusterShapeMNK,
    cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    MainloopScheduleType
  >::CollectiveOp;
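
  // The collective builders above assemble the epilogue and mainloop from dispatch tags:
  // MainloopScheduleType selects how A and B are staged into shared memory (TMA, CPASYNC, or the
  // mixed TMA-for-A / CPASYNC-for-B schedule this example targets), and StageCountAutoCarveout
  // reserves the epilogue's shared-memory footprint before the builder chooses the number of
  // mainloop pipeline stages. (A high-level summary; the exact staging is decided by the builder.)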

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,
    CollectiveMainloop,
    CollectiveEpilogue
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
  using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
  using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
  using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;

  //
  // Data members
  //

  /// Initialization
  StrideA stride_A;
  StrideB stride_B;
  StrideC stride_C;
  StrideD stride_D;
  uint64_t seed = 0;

  cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
  cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
  cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
  cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
  cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;

  //
  // Methods
  //

  bool verify(ProblemShapeType const& problem_size, float alpha, float beta) {
    auto [M, N, K, L] = problem_size;

    cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
    cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
    cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
    cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));

    cutlass::reference::device::GemmComplex(
      {M, N, K},
      ElementScalar(alpha),
      ref_A,
      cutlass::ComplexTransform::kNone,
      ref_B,
      cutlass::ComplexTransform::kNone,
      ElementScalar(beta),
      ref_C,
      ref_D,
      ElementAccumulator(0),
      L,     // batch_count
      M * K, // batch_stride_A
      K * N, // batch_stride_B
      M * N, // batch_stride_C
      M * N  // batch_stride_D
    );

    cudaError_t result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      std::cerr << "Reference kernel failed. Last CUDA error: "
                << cudaGetErrorString(result) << std::endl;
      return false;
    }

    // Check whether the outputs of the CUTLASS kernel and the reference kernel are equal
    bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());

    return passed;
  }

  /// Initialize operands to be used in the GEMM and reference GEMM
  void initialize(ProblemShapeType const& problem_size) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto [M, N, K, L] = problem_shape_MNKL;

    stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
    stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
    stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
    stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));

    block_A.reset(size_t(1) * M * K * L);
    block_B.reset(size_t(1) * K * N * L);
    block_C.reset(size_t(1) * M * N * L);
    block_D.reset(size_t(1) * M * N * L);
    block_ref_D.reset(size_t(1) * M * N * L);

    initialize_block(block_A, seed + 2023);
    initialize_block(block_B, seed + 2022);
    initialize_block(block_C, seed + 2021);
  }
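
  // Note: make_cute_packed_stride() above is assumed to produce fully packed, batched strides for
  // the row-major A and column-major B/C/D layouts, so the per-batch strides used by the reference
  // GEMM in verify() (M*K, K*N, M*N) match the device-side data layout.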

  bool run(Options const& options, cutlass::KernelHardwareInfo const& hw_info) {
    ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};

    initialize(problem_size);

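    // The arguments are, in order: the GEMM mode, the (M,N,K,L) problem shape, the mainloop
    // operands {ptr_A, stride_A, ptr_B, stride_B}, the epilogue arguments {fusion thread params,
    // ptr_C, stride_C, ptr_D, stride_D}, and the hardware info. Alpha and beta are filled in on
    // the fusion thread params just below.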
    typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      {block_A.get(), stride_A, block_B.get(), stride_B},
      {{}, // epilogue.thread
       block_C.get(), stride_C, block_D.get(), stride_D},
      hw_info
    };

    // arguments.scheduler.max_swizzle_size = options.swizzle;

    arguments.epilogue.thread.alpha = 1.0f;
    arguments.epilogue.thread.beta = 0.0f;

    Gemm gemm_op;

    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
      std::cerr << "This kernel is not supported. Last CUDA error is: "
                << cudaGetErrorString(cudaGetLastError()) << std::endl;
      return false;
    }

    status = gemm_op.initialize(arguments, workspace.get());
    if (status != cutlass::Status::kSuccess) {
      std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
                << cudaGetErrorString(cudaGetLastError()) << std::endl;
      return false;
    }

    // Run the GEMM
    status = gemm_op.run();
    if (status != cutlass::Status::kSuccess) {
      std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
                << cudaGetErrorString(cudaGetLastError()) << std::endl;
      return false;
    }

    cudaError_t result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
                << cudaGetErrorString(result) << std::endl;
      return false;
    }

    if (options.verification) {
      // Verify that the result is correct
      bool passed = verify(problem_size, 1.0f, 0.0f);

      std::cout << "  Disposition: " << (passed ? "Passed" : "Failed") << std::endl;

      if (!passed) {
        exit(-1);
        return false;
      }
    }

    // Run profiling loop
    if (options.iterations > 0)
    {
      GpuTimer timer;
      timer.start();
      for (int iter = 0; iter < options.iterations; ++iter) {
        CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
        CUTLASS_CHECK(gemm_op.run());
      }
      timer.stop();

      // Compute average setup and runtime and FLOPs.
      float elapsed_ms = timer.elapsed_millis();
      double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
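      // A GEMM performs 2*M*N*K floating-point operations (one multiply and one add per
      // multiply-accumulate); with L batched GEMMs that is 2*M*N*K*L FLOPs, divided by the
      // average runtime in seconds to obtain FLOP/s.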
      double flops = double(int64_t(2) * options.m * options.n * options.k * options.l) / (avg_runtime_ms / 1000.0);

      std::cout << "  Avg runtime : " << avg_runtime_ms << " ms" << std::endl;
      std::cout << "  TFLOPS      : " << flops / 1e12 << std::endl;
    }

    return true;
  }

};

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
    return 0;
  }

  if (!(props.major == 10 && props.minor == 0)) {
    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
    return 0;
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.error) {
    std::cerr << "Aborting execution." << std::endl;
    return -1;
  }

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

  //
  // Run examples
  //

  // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
  // information is used by the underlying kernel.
  cutlass::KernelHardwareInfo hw_info;

  // Change device_id to another value if you are running on a machine with multiple GPUs and wish
  // to use a GPU other than that with device ID 0.
  hw_info.device_id = 0;
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

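  // Run the same problem with three mainloop schedules so their timings can be compared: TMA for
  // both operands, CPASYNC for both operands, and the mixed TMA (A) + CPASYNC (B) schedule this
  // example is about.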
std::cout << "Running kernel with TMA load:" << std::endl;
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100> runner_tma;
|
||||
runner_tma.run(options, hw_info);
|
||||
|
||||
std::cout << "Running kernel with CPASYNC load:" << std::endl;
|
||||
ExampleRunner<cutlass::gemm::KernelWarpSpecialized1SmSm100> runner_cpasync;
|
||||
runner_cpasync.run(options, hw_info);
|
||||
|
||||
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
|
||||
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100> runner_mixed_tma_cpasync;
|
||||
runner_mixed_tma_cpasync.run(options, hw_info);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
70
examples/92_blackwell_moe_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,70 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

set(TEST_MIXTRAL_A --m=28672 --n=4 --k=4096 --l=8)
set(TEST_MIXTRAL_B --m=4096 --n=4 --k=14336 --l=8)
set(TEST_DEEPSEEK_A --m=4096 --n=1 --k=7168 --l=256)
set(TEST_DEEPSEEK_B --m=7168 --n=1 --k=2048 --l=256)
set(TEST_IRREGULAR_MNK --m=4080 --n=9 --k=4112 --l=8) # M,N,K not multiples of tile size

set(TEST_DEEPSEEK_A_FP4 --m=1024 --n=1 --k=7168 --l=256) # TP=1 shape is too large for PackedVectorLayout
set(TEST_DEEPSEEK_B_FP4 --m=7168 --n=1 --k=512 --l=256)
set(TEST_IRREGULAR_MNK_FP4 --m=4080 --n=9 --k=4160 --l=8)

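# These test shapes model MoE decode workloads: l appears to be the expert (batch) count and n the
# small per-expert token count; the *_FP4 variants use reduced extents for the narrow-precision
# executables (see the PackedVectorLayout note above). They are passed to the example binaries via
# TEST_COMMAND_OPTIONS below.
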
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
  cutlass_example_add_executable(
    92_blackwell_moe_gemm_regular
    92_blackwell_moe_gemm_regular.cu
    TEST_COMMAND_OPTIONS
    TEST_MIXTRAL_A
    TEST_MIXTRAL_B
    TEST_DEEPSEEK_A
    TEST_DEEPSEEK_B
    TEST_IRREGULAR_MNK
  )

  cutlass_example_add_executable(
    92_blackwell_moe_gemm_grouped
    92_blackwell_moe_gemm_grouped.cu
  )

  cutlass_example_add_executable(
    92_blackwell_moe_gemm_fp4_regular
    92_blackwell_moe_gemm_fp4_regular.cu
    TEST_COMMAND_OPTIONS
    TEST_MIXTRAL_A
    TEST_MIXTRAL_B
    TEST_DEEPSEEK_A_FP4
    TEST_DEEPSEEK_B_FP4
    TEST_IRREGULAR_MNK_FP4
  )

  cutlass_example_add_executable(
    92_blackwell_moe_gemm_fp4_grouped
    92_blackwell_moe_gemm_fp4_grouped.cu
  )
endif()
@ -163,7 +163,13 @@ foreach(EXAMPLE
  82_blackwell_distributed_gemm
  83_blackwell_sparse_gemm
  84_blackwell_narrow_precision_sparse_gemm
  86_blackwell_mixed_dtype_gemm
  87_blackwell_geforce_gemm_blockwise
  88_hopper_fmha
  89_sm103_fp4_ultra_gemm
  90_sm103_fp4_ultra_grouped_gemm
  91_fp4_gemv
  92_blackwell_moe_gemm
)

  add_subdirectory(${EXAMPLE})
Some files were not shown because too many files have changed in this diff