Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b995f93317 | |||
| 889ff20648 | |||
| dc4817921e | |||
| 5c6bca0441 | |||
| c2ad7c5b20 | |||
| cc23f6d1e9 | |||
| 5a287538c2 | |||
| 8bdbfca682 | |||
| 2e2af190bd | |||
| f12b1d75c9 | |||
| b244379d9b | |||
| 9d165a3b8e | |||
| b9b110a9ea | |||
| 8206e7a0f5 | |||
| 6316b6f867 | |||
| 9354bfd7c1 | |||
| 5e9b8e2a25 | |||
| 1ec230c4bf | |||
| f89cd95b16 | |||
| f115c3f854 |
112
.github/workflows/blossom-ci.yml
vendored
Normal file
112
.github/workflows/blossom-ci.yml
vendored
Normal file
@ -0,0 +1,112 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
# A workflow to trigger ci on hybrid infra (github + self hosted runner)
|
||||
name: Blossom-CI
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
platform:
|
||||
description: 'runs-on argument'
|
||||
required: false
|
||||
args:
|
||||
description: 'argument'
|
||||
required: false
|
||||
|
||||
jobs:
|
||||
Authorization:
|
||||
name: Authorization
|
||||
runs-on: blossom
|
||||
outputs:
|
||||
args: ${{ env.args }}
|
||||
|
||||
# This job only runs for pull request comments
|
||||
if: |
|
||||
(startsWith(github.event.comment.body, '/bot run') ||
|
||||
startsWith(github.event.comment.body, '/bot kill')) && contains(
|
||||
fromJson('["zekunf-nv"]'),
|
||||
github.actor)
|
||||
steps:
|
||||
- name: Check if comment is issued by authorized person
|
||||
run: blossom-ci
|
||||
env:
|
||||
OPERATION: 'AUTH'
|
||||
REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
|
||||
|
||||
Vulnerability-scan:
|
||||
name: Vulnerability scan
|
||||
needs: [Authorization]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
repository: ${{ fromJson(needs.Authorization.outputs.args).repo }}
|
||||
ref: ${{ fromJson(needs.Authorization.outputs.args).ref }}
|
||||
lfs: 'true'
|
||||
|
||||
- name: Run blossom action
|
||||
uses: NVIDIA/blossom-action@main
|
||||
env:
|
||||
REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
|
||||
with:
|
||||
args1: ${{ fromJson(needs.Authorization.outputs.args).args1 }}
|
||||
args2: ${{ fromJson(needs.Authorization.outputs.args).args2 }}
|
||||
args3: ${{ fromJson(needs.Authorization.outputs.args).args3 }}
|
||||
|
||||
Job-trigger:
|
||||
name: Start ci job
|
||||
needs: [Vulnerability-scan]
|
||||
runs-on: blossom
|
||||
steps:
|
||||
- name: Start ci job
|
||||
run: blossom-ci
|
||||
env:
|
||||
OPERATION: 'START-CI-JOB'
|
||||
CI_SERVER: ${{ secrets.CI_SERVER }}
|
||||
REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
Upload-Log:
|
||||
name: Upload log
|
||||
runs-on: blossom
|
||||
if : github.event_name == 'workflow_dispatch'
|
||||
steps:
|
||||
- name: Jenkins log for pull request ${{ fromJson(github.event.inputs.args).pr }} (click here)
|
||||
run: blossom-ci
|
||||
env:
|
||||
OPERATION: 'POST-PROCESSING'
|
||||
CI_SERVER: ${{ secrets.CI_SERVER }}
|
||||
REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
529
CHANGELOG.md
529
CHANGELOG.md
@ -1,56 +1,115 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
# Changelog
|
||||
|
||||
## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03)
|
||||
# CUTLASS 4.x
|
||||
|
||||
* Fixed [Blockwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM hang issue when problem size K is 128.
|
||||
## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)
|
||||
|
||||
### CuTe DSL
|
||||
* CuTe DSL, a Python DSL centered around CuTe's abstractions
|
||||
- [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
|
||||
- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
|
||||
- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
|
||||
* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass)
|
||||
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
|
||||
- [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
|
||||
- [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
|
||||
- [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
|
||||
- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
|
||||
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
|
||||
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
|
||||
- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
|
||||
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
|
||||
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
|
||||
* API updates
|
||||
- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
|
||||
|
||||
### CUTLASS C++
|
||||
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
|
||||
- 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note 101a is supported since CUTLASS 3.9
|
||||
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
|
||||
- For example:
|
||||
+ `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
|
||||
+ `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
|
||||
- If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, filter kernels in the CUTLASS profiler with `--kernels`), please update your uses accordingly, this is a breaking change.
|
||||
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
|
||||
- Added non-power-of-two tile sizes.
|
||||
- Improved performance for K-major scale factors.
|
||||
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
|
||||
* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
- Support LSE output in FMHA Forward kernel.
|
||||
- Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent.
|
||||
- Enhance testing of variable sequence length.
|
||||
- Disable B2B mode in MLA to simplify the sample.
|
||||
- Clarify that `fmha_gen` sample only supports head dim 128.
|
||||
- Fixes for split-kv output in MLA.
|
||||
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
|
||||
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
|
||||
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
|
||||
* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
|
||||
* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels.
|
||||
* Fix profiler issues which cause no output or not supported error for some kernels.
|
||||
* Optimizations for Blackwell SM100 and SM120 block scaled kernels.
|
||||
* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler.
|
||||
* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
* CuTe changes:
|
||||
- Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda and introduces transform-tensors to avoid any extra register or load/store overhead in using bool-tensors.
|
||||
- New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy.
|
||||
- Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
|
||||
- Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
|
||||
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
* Optimal code generation with CUDA toolkit versions 12.9.
|
||||
|
||||
|
||||
## [3.9.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.1) (2025-04-30)
|
||||
# CUTLASS 3.x
|
||||
|
||||
## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03)
|
||||
* Fixed [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM hang issue when problem size K is 128.
|
||||
* Optimal code generation with CUDA toolkit versions 12.9.
|
||||
|
||||
## [3.9.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.1) (2025-04-30)
|
||||
* Fixed Group Gemm hang issue in CUTLASS 3.x
|
||||
* Improved Hopper [Blockwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM performance.
|
||||
* Improved Hopper [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMM performance.
|
||||
|
||||
## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-04-24)
|
||||
|
||||
* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
|
||||
- Collective mainloops that target for:
|
||||
* [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
|
||||
* [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
|
||||
* [Blockscaled datatypes with support for dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
|
||||
* [Blockscaled datatypes with support for sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
|
||||
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell SM120 epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
|
||||
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
|
||||
- [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
|
||||
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
|
||||
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
|
||||
- [Grouped GEMM with nvfp4 datatype](https://github.com/NVIDIA/cutlass/tree/main/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](https://github.com/NVIDIA/cutlass/tree/main/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
|
||||
* Set of unit tests that demonstrate the usage of both [sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
|
||||
* Support for Blackwell SM100 Sparse kernels:
|
||||
- Collective mainloop that target for
|
||||
* [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
|
||||
* [SM100 Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
|
||||
* Set of example that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
|
||||
- [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
|
||||
* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
|
||||
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case.
|
||||
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
|
||||
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
|
||||
- [Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with NVFP4 input data type](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
|
||||
* Set of unit tests that demonstrate the usage of [sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
|
||||
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case.
|
||||
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
|
||||
* A new [distributed GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
|
||||
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
|
||||
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
|
||||
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
|
||||
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
|
||||
- Support for [grouped-wise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler.
|
||||
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
|
||||
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Enhancement of [blockwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
|
||||
- Enhancement of [groupwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
|
||||
- Support for [grouped GEMM with blockwise and groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
|
||||
- Support for [grouped-wise GEMM](https://github.com/NVIDIA/cutlass/tree/main/tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler.
|
||||
- Support for [blockwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
|
||||
- Support for [groupwise GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Support for [grouped GEMM with blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
|
||||
* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler:
|
||||
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
|
||||
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
|
||||
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
|
||||
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
|
||||
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
|
||||
* Support `void` as the D element in sm100 kernel epilogues.
|
||||
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
* Optimal code generation with CUDA toolkit versions 12.8U1.
|
||||
@ -58,32 +117,32 @@
|
||||
## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25)
|
||||
|
||||
* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
|
||||
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
|
||||
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
|
||||
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
|
||||
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
|
||||
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
|
||||
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
|
||||
- [5th generation Blackwell Tensor Core instructions (TCGen05)](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
|
||||
- Extensions to [Tensor Memory Accelerator](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
|
||||
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/pointer.hpp) across CuTe as a first class data locale.
|
||||
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
|
||||
- [`make_tmem_copy()`](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
|
||||
- Support for [new variants of LDSM on Blackwell](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
|
||||
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
|
||||
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
|
||||
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
|
||||
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
|
||||
- Various narrow precision [FP4, FP6, and FP8](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/float_subbyte.h)
|
||||
- [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp).
|
||||
- [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp).
|
||||
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
|
||||
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
|
||||
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
|
||||
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
|
||||
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
|
||||
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
|
||||
- [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
|
||||
+ Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
|
||||
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
|
||||
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
|
||||
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
|
||||
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
|
||||
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
|
||||
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
|
||||
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
|
||||
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
|
||||
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
|
||||
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
|
||||
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
|
||||
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
|
||||
- Blackwell [collective mainloop for convolution kernels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
|
||||
- New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp), [convolution](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/dispatch_policy.hpp), and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell epilogue that supports loading accumulators from `tmem`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and full set of EVT fusions.
|
||||
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
|
||||
- Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
|
||||
- Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors.
|
||||
@ -91,131 +150,131 @@
|
||||
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
|
||||
* A new 3.x version of grouped GEMM to the CUTLASS library and generates kernels for Hopper and Blackwell. Now grouped GEMM support is enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
|
||||
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
|
||||
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
|
||||
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](https://github.com/NVIDIA/cutlass/tree/main/examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
|
||||
- GEMM with [opt-in collective builder schedules showcasing available recipes](https://github.com/NVIDIA/cutlass/tree/main/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
|
||||
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
|
||||
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
|
||||
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
|
||||
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
|
||||
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
|
||||
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
|
||||
- A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
|
||||
+ [NVFP4 inputs with BF16 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
+ [NVFP4 inputs with NVFP4 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
|
||||
- [GEMM with CLC based StreamK scheduler for load balancing](https://github.com/NVIDIA/cutlass/tree/main/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
|
||||
- Grouped GEMM for [vanilla FP8 data inputs](https://github.com/NVIDIA/cutlass/tree/main/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](https://github.com/NVIDIA/cutlass/tree/main/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
|
||||
- Convolution kernels for [fprop](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](https://github.com/NVIDIA/cutlass/tree/main/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
|
||||
- [Fused multi-head attention fprop kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
|
||||
- A new BF16x9 GEMM [kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
|
||||
- A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
|
||||
- A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
|
||||
- A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
|
||||
- A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
|
||||
* Documentation updates:
|
||||
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-gemm-kernel).
|
||||
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md)
|
||||
- A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
|
||||
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
|
||||
- Updates to [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper.
|
||||
- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
|
||||
- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html)
|
||||
- A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
|
||||
- Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture).
|
||||
- Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.
|
||||
|
||||
## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
|
||||
- [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
|
||||
- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
|
||||
- [Distributed GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
|
||||
- Potential API breaking changes:
|
||||
+ Fix `cute::UniversalCopy` for type safety.
|
||||
+ No longer implicitly select `cute::SM80_CP_ASYNC_*` based on input tensors. This avoids implicit downstream synchronization requirements. To use `SM80_CP_ASYNC`, users must explicitly select the appropriate CopyAtom.
|
||||
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
|
||||
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
|
||||
+ A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
|
||||
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler).
|
||||
+ A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
|
||||
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit versions 12.6.
|
||||
|
||||
## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03)
|
||||
|
||||
- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
|
||||
+ [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
|
||||
+ [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
|
||||
+ [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
|
||||
- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/cpp/dependent_kernel_launch.md).
|
||||
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
|
||||
- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
|
||||
- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper).
|
||||
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h)
|
||||
- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu).
|
||||
- [Hopper structured sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
|
||||
+ [FP16](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
|
||||
+ [FP8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
|
||||
+ [INT8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
|
||||
- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
- [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html).
|
||||
- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
|
||||
- A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
|
||||
- A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
|
||||
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h)
|
||||
- Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit versions 12.6.
|
||||
|
||||
## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25)
|
||||
|
||||
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu)
|
||||
- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48)
|
||||
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/cpp/profiler.md#GEMM), and
|
||||
[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
|
||||
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence:
|
||||
+ [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
|
||||
+ [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
|
||||
- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
|
||||
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py).
|
||||
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/wgmma_sm90.cu)
|
||||
- [Exposure of L2 `cache_hint`s in TMA copy atoms](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/copy_sm90_tma.hpp#L48)
|
||||
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/cpp/profiler.md#gemm), and
|
||||
[example 48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- [TMA store based and EVT supported epilogues](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
|
||||
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence:
|
||||
+ [FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
|
||||
+ [int8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [int4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [FP32 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
|
||||
- [CUDA host adapter](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
|
||||
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/generator.py).
|
||||
- Support for residual add (beta != 0) in convolution kernels.
|
||||
- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
|
||||
- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
|
||||
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md).
|
||||
- A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
|
||||
- A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt).
|
||||
- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html).
|
||||
- Better support for MSVC as a host compiler.
|
||||
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
|
||||
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
|
||||
|
||||
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
|
||||
|
||||
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
|
||||
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md).
|
||||
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
|
||||
+ Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
|
||||
+ [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
|
||||
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp)
|
||||
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html).
|
||||
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp).
|
||||
+ Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
|
||||
+ [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
|
||||
+ NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design!
|
||||
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
|
||||
- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
|
||||
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](https://github.com/NVIDIA/cutlass/tree/main/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
|
||||
- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
|
||||
+ Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
|
||||
+ Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
|
||||
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
|
||||
+ [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
|
||||
+ [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
|
||||
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
|
||||
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
|
||||
+ [Ampere FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
|
||||
+ [Turing FP16 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
|
||||
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial).
|
||||
- Extensions to CuTe to support [L2 prefetching](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](https://github.com/NVIDIA/cutlass/tree/main/include/cute/arch/copy_sm90_tma.hpp#L1337).
|
||||
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
|
||||
- Fixes to greatly reduce build warnings.
|
||||
- Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
|
||||
|
||||
- Statically available [CUTLASS Version macros](./include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
|
||||
- Improvements for Hopper [Group-GEMMs](./examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm).
|
||||
- Statically available [CUTLASS Version macros](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
|
||||
- Improvements for Hopper [Group-GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/56_hopper_ptr_array_batched_gemm).
|
||||
- Updates and bugfixes from the community (thanks!).
|
||||
|
||||
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
|
||||
* Expanded [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
|
||||
* Performance improvements to [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm)
|
||||
* Beta release of [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released.
|
||||
* Improved [CuTe documentation](./media/docs/cpp/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
|
||||
* Expanded [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
|
||||
* Performance improvements to [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm)
|
||||
* Beta release of [Pointer-Array Batched GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* Beta release of [Group-GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* [Ampere Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/barrier.h) has been officially released.
|
||||
* Improved CuTe documentation including improved clarity and depth of [Quickstart](./media/docs/cpp/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
|
||||
|
||||
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
|
||||
* [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
|
||||
* [Mixed-input Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
|
||||
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}.
|
||||
* [Copy Async based Hopper GEMMs](./test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
|
||||
* [Copy Async based Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
|
||||
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
|
||||
* Profiler support for lower-aligned Hopper GEMMs.
|
||||
* Performance Improvements to [Scatter-Gather Hopper Example](./examples/52_hopper_gather_scatter_fusion).
|
||||
* Performance Improvements to [Scatter-Gather Hopper Example](https://github.com/NVIDIA/cutlass/tree/main/examples/52_hopper_gather_scatter_fusion).
|
||||
* Sub-Byte type fixes and improvements.
|
||||
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](./include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
|
||||
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
|
||||
* Fusion support for backprop fusions including drelu, dgelu, and dbias.
|
||||
* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
|
||||
|
||||
@ -227,7 +286,7 @@
|
||||
* SM80 EVT support in C++ and Python.
|
||||
* Other SM90 epilogue improvements.
|
||||
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
|
||||
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](./python/README.md) for details.
|
||||
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) for details.
|
||||
* SM90 TF32 kernel improvements for all layouts.
|
||||
* SM90 rasterization direction support in the CUTLASS profiler.
|
||||
* Improvement for CUTLASS profiler build times.
|
||||
@ -235,34 +294,34 @@
|
||||
|
||||
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
|
||||
|
||||
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](./examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
|
||||
* New [Epilogue Visitor Tree (EVT)](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
* [Stream-K](./include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](./include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* [Hopper GEMM+Permute](./examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
* New CUTLASS 2D Convolution Python interface. New [example](./examples/python/03_basic_conv2d.ipynb) here.
|
||||
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
|
||||
* New [Epilogue Visitor Tree (EVT)](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
* [Stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* [Hopper GEMM+Permute](https://github.com/NVIDIA/cutlass/tree/main/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
* New CUTLASS 2D Convolution Python interface. New [example](https://github.com/NVIDIA/cutlass/tree/main/examples/python/03_basic_conv2d.ipynb) here.
|
||||
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
|
||||
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](./python/README.md) and new [examples](./examples/python).
|
||||
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues.
|
||||
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* New [*warp-specialized persistent cooperative*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
* An [example](./examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python).
|
||||
* New [efficient epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
* Support for [fused epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues.
|
||||
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* New [*warp-specialized persistent cooperative*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
* An [example](https://github.com/NVIDIA/cutlass/tree/main/examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
|
||||
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) support both row major and column major input matrix.
|
||||
* [Permute + GEMM fusion](./examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue.
|
||||
* [Row Broadcast](./include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
|
||||
* [FMHA Backward Pass](https://github.com/NVIDIA/cutlass/tree/main/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](https://github.com/NVIDIA/cutlass/tree/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
|
||||
* [Batched Strided GEMV](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemv.cu) support both row major and column major input matrix.
|
||||
* [Permute + GEMM fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue.
|
||||
* [Row Broadcast](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
|
||||
* The GitHub branch is renamed from `master` to `main` in this release.
|
||||
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
@ -272,28 +331,30 @@
|
||||
* [A new conceptual operation hierarchy](./media/docs/cpp/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/cpp/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cpp/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](./media/docs/cpp/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#target-architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
|
||||
* [CUTLASS library integration](./tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
|
||||
* Support for [Hopper GEMMs](./examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
|
||||
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](./examples/48_hopper_warp_specialized_gemm), [49](./examples/49_hopper_gemm_schedules_with_collective_builder), and [50](./examples/50_hopper_gemm_with_epilogue_swizzle).
|
||||
* [CUTLASS library integration](https://github.com/NVIDIA/cutlass/tree/main/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
|
||||
* Support for [Hopper GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
|
||||
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](https://github.com/NVIDIA/cutlass/tree/main/examples/48_hopper_warp_specialized_gemm), [49](https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_schedules_with_collective_builder), and [50](https://github.com/NVIDIA/cutlass/tree/main/examples/50_hopper_gemm_with_epilogue_swizzle).
|
||||
|
||||
# CUTLASS 2.x
|
||||
|
||||
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
|
||||
* [Stream-K](./examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
|
||||
* [Fused multi-head attention Kernel](./examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
|
||||
* [Dual GEMM](./examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency.
|
||||
* Hopper improves [double precision matrix multiplication](./test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
|
||||
* [BLAS3](./test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hoppers new double precision matrix multiplication instructions.
|
||||
* [ELL Block Sparse GEMM](./examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary.
|
||||
* Optimized [Group Conv](./examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
|
||||
* [Optimized DepthWise Conv](./examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added
|
||||
* [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM.
|
||||
* [Stream-K](https://github.com/NVIDIA/cutlass/tree/main/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
|
||||
* [Fused multi-head attention Kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
|
||||
* [Dual GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency.
|
||||
* Hopper improves [double precision matrix multiplication](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
|
||||
* [BLAS3](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hoppers new double precision matrix multiplication instructions.
|
||||
* [ELL Block Sparse GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary.
|
||||
* Optimized [Group Conv](https://github.com/NVIDIA/cutlass/tree/main/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
|
||||
* [Optimized DepthWise Conv](https://github.com/NVIDIA/cutlass/tree/main/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added
|
||||
* [kOptimized](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM.
|
||||
* The restrictions are: 1) input ,output channel and group number should be multiple of (128 / sizeof(input element)). 2) The input filter size should be the same as the template parameter configuration.
|
||||
* [kFixedStrideDilation](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded.
|
||||
* [kFixedStrideDilation](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded.
|
||||
* The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration.
|
||||
* [Scripts](./examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
|
||||
* [FP8 data type definition](./include/cutlass/float8.h) and [conversion routines](./include/cutlass/numeric_conversion.h#L1274-2115).
|
||||
* [Scripts](https://github.com/NVIDIA/cutlass/tree/main/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
|
||||
* [FP8 data type definition](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/float8.h) and [conversion routines](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/numeric_conversion.h#L1274-2115).
|
||||
* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
|
||||
|
||||
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
@ -302,54 +363,54 @@
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
|
||||
* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
|
||||
* Optimizations for CUTLASS's [Grouped GEMM](./examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](./examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
|
||||
* Optimizations for [GEMM+Softmax](./examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
|
||||
* [Grouped GEMM for Multihead Attention](./examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing.
|
||||
* [GEMM + Layer norm fusion for Ampere](./examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues.
|
||||
* [GEMM Epilogue Permutation Fusion](./examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue.
|
||||
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes:
|
||||
* [CUTLASS Python](https://github.com/NVIDIA/cutlass/tree/main/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
|
||||
* Optimizations for CUTLASS's [Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](https://github.com/NVIDIA/cutlass/tree/main/examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
|
||||
* Optimizations for [GEMM+Softmax](https://github.com/NVIDIA/cutlass/tree/main/examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
|
||||
* [Grouped GEMM for Multihead Attention](https://github.com/NVIDIA/cutlass/tree/main/examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing.
|
||||
* [GEMM + Layer norm fusion for Ampere](https://github.com/NVIDIA/cutlass/tree/main/examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues.
|
||||
* [GEMM Epilogue Permutation Fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue.
|
||||
* [Grouped convolution targeting implicit GEMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes:
|
||||
* kSingleGroup: output channel per group is multiple of Threadblock tile N.
|
||||
* kMultipleGroup: Threadblock tile N is multiple of output channel per group.
|
||||
* [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number.
|
||||
* Standalone [Layernorm](./tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](./tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
|
||||
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
|
||||
* [Depthwise separable convolution](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number.
|
||||
* Standalone [Layernorm](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
|
||||
* [Back-to-back GEMM/CONV](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
|
||||
|
||||
* [First layer Convolution kernels](./test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](./test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](./python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [First layer Convolution kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
|
||||
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
|
||||
* [HERK](./test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
|
||||
* [SYRK](./test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
|
||||
* [SYMM](./test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/symm_operation.py)
|
||||
* [TRMM](./test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/trmm_operation.py)
|
||||
* [Unit tests](./test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](./examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](./tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](./examples/35_gemm_softmax)
|
||||
* [Gather and Scatter Fusion with GEMM](./examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel.
|
||||
* [HERK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/rank_k_operation.py)
|
||||
* [SYRK](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/rank_k_operation.py)
|
||||
* [SYMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/symm_operation.py)
|
||||
* [TRMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/trmm_operation.py)
|
||||
* [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](https://github.com/NVIDIA/cutlass/tree/main/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](https://github.com/NVIDIA/cutlass/tree/main/tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](https://github.com/NVIDIA/cutlass/tree/main/examples/35_gemm_softmax)
|
||||
* [Gather and Scatter Fusion with GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel.
|
||||
* It can select random rows in a row major matrix.
|
||||
* It can select random columns in a column major matrix.
|
||||
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
|
||||
* [Back-to-back GEMM/CONV](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
|
||||
* Supported kernels: GEMM and CONV.
|
||||
* Supported types: fp16 and int8.
|
||||
* Supported architectures: Turing and Ampere.
|
||||
* [Transposed Convolution](./examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
|
||||
* [Utility functions](./tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
|
||||
* [Transposed Convolution](https://github.com/NVIDIA/cutlass/tree/main/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
|
||||
* [Utility functions](https://github.com/NVIDIA/cutlass/tree/main/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
|
||||
* [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores in these kernels.
|
||||
* Epilogue enhancement:
|
||||
* Eliminate bank conflicts in int8 tensor core kernels.
|
||||
* Half2 usage if epilogue compute type is fp16.
|
||||
* More activation functions: Silu, Hardswish, Leaky Relu.
|
||||
* New elementwise fusion pattern for [residual block](./include/cutlass/epilogue/thread/linear_combination_residual_block.h).
|
||||
* [Group GEMM](./examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
|
||||
* New elementwise fusion pattern for [residual block](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
|
||||
* [Group GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
|
||||
* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
@ -359,17 +420,17 @@
|
||||
|
||||
* **TF32x3:** emulated single-precision using Tensor Cores
|
||||
* 45+ TFLOPs on NVIDIA A100
|
||||
* [GEMM SDK example](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* [GEMM SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
|
||||
* [Conv Fprop SDK example](./examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](./include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* [Conv Fprop SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
|
||||
* [SDK example](./examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](./include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](./examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
|
||||
* [SDK example](https://github.com/NVIDIA/cutlass/tree/main/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
|
||||
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates from the community (thanks!)
|
||||
|
||||
@ -379,13 +440,13 @@
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](./examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](./include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](./include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](https://github.com/NVIDIA/cutlass/tree/main/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Use these when accumulation and epilogue compute types are all `cutlass::half_t`
|
||||
* Tuning and bug fixes to [fused GEMM + GEMM example](./examples/13_two_tensor_op_fusion/)
|
||||
* Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
|
||||
* Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
|
||||
* Tuning and bug fixes to [fused GEMM + GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/)
|
||||
* Support for smaller than 128b aligned Convolutions: [see examples](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
|
||||
* Caching of results to accelerate Convolution [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/cache_testbed_output.h)
|
||||
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
@ -398,24 +459,24 @@
|
||||
|
||||
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
|
||||
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](./include/cutlass/arch/memory.h) and [global load](./include/cutlass/arch/memory_sm80.h)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/memory.h) and [global load](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/memory_sm80.h)
|
||||
* Fused operators with GEMM and Convolution
|
||||
* [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused broadcast in epilogue](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* 64b tensor strides and leading dimensions support for GEMMs
|
||||
* Affine rank=2 matrix layouts
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](https://github.com/NVIDIA/cutlass/tree/main/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* [New strided Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
|
||||
* Accelerates over previous implementation by cutting down redundant math by 4x
|
||||
* Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface
|
||||
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
|
||||
* Updates to [quaternion.h](./include/cutlass/quaternion.h) and [functional.h](./include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](./examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](./examples/22_quaternion_conv/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](./test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](./test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Updates to [quaternion.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/quaternion.h) and [functional.h](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](https://github.com/NVIDIA/cutlass/tree/main/examples/22_quaternion_conv/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Many improvements to the epilogue.
|
||||
* Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Provide an [option](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Performance improvement for FP16 tensor core kernels
|
||||
* Bug fixes
|
||||
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
|
||||
@ -427,14 +488,14 @@
|
||||
## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
|
||||
* Tensor reductions
|
||||
* _m_-to-_n_ reductions of tensors with affine layout
|
||||
* [Specializations](./test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](./test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* [Specializations](https://github.com/NVIDIA/cutlass/tree/main/test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](https://github.com/NVIDIA/cutlass/tree/main/test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* Custom reduction functors such as `cutlass::logical_and`
|
||||
* Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31)
|
||||
* Optimizations for 3-D convolution
|
||||
* [Optimized tile iterators](./include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* Full coverage of [forward](test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
|
||||
* [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md)
|
||||
* [Optimized tile iterators](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* Full coverage of [forward](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
|
||||
* [Fused Convolution+Convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/13_two_tensor_op_fusion/README.md)
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
@ -453,16 +514,16 @@
|
||||
|
||||
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
* [Sparse Tensor Core GEMM kernels](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
|
||||
* [Sparse Tensor Core GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
|
||||
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
|
||||
* Minor Features:
|
||||
* [Activation functions](./include/cutlass/epilogue/thread/activation.h) such as [GeLU](./include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](./include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](./include/cutlass/constants.h)
|
||||
* [Activation functions](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/activation.h) such as [GeLU](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/matrix.h) and [quaternion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/constants.h)
|
||||
* NVIDIA Ampere GPU Architecture examples and documentation:
|
||||
* [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* [Tensor Float 32](https://github.com/NVIDIA/cutlass/tree/main/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](https://github.com/NVIDIA/cutlass/tree/main/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/cpp/gemm_api.md#efficient-epilogue)
|
||||
|
||||
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
|
||||
@ -487,7 +548,7 @@
|
||||
* API to launch compiled kernel instances for GEMM and planar complex GEMM
|
||||
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
|
||||
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
|
||||
* [SDK Examples of Planar Complex GEMMs](./examples/10_planar_complex/planar_complex.cu)
|
||||
* [SDK Examples of Planar Complex GEMMs](https://github.com/NVIDIA/cutlass/tree/main/examples/10_planar_complex/planar_complex.cu)
|
||||
* Minor enhancements and bug fixes
|
||||
|
||||
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
|
||||
|
||||
@ -175,7 +175,13 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 101 101a 120 120a)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
@ -340,6 +346,10 @@ if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")
|
||||
|
||||
#
|
||||
@ -676,25 +686,6 @@ if (NOT CUTLASS_NAMESPACE STREQUAL "cutlass")
|
||||
target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE})
|
||||
endif()
|
||||
|
||||
if (NOT DEFINED CUTLASS_REVISION)
|
||||
|
||||
find_package(Git QUIET)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD
|
||||
RESULT_VARIABLE CUTLASS_REVISION_RESULT
|
||||
OUTPUT_VARIABLE CUTLASS_REVISION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
if (CUTLASS_REVISION_RESULT)
|
||||
message(STATUS "CUTLASS Revision: Unable to detect, Git returned code ${CUTLASS_REVISION_RESULT}.")
|
||||
else()
|
||||
message(STATUS "CUTLASS Revision: ${CUTLASS_REVISION}")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
[README](./README.md#documentation) > **Contributors**
|
||||
|
||||
# CUTLASS Developers **
|
||||
# CUTLASS C++ Developers **
|
||||
|
||||
Andrew Kerr<br />
|
||||
Paul Springer<br />
|
||||
@ -70,8 +70,49 @@ Shreya Gaur<br />
|
||||
|
||||
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
|
||||
|
||||
# CUTLASS DSL Developers ***
|
||||
|
||||
# CUTE Developers
|
||||
Albert Di<br />
|
||||
Albert Xu<br />
|
||||
Anakin Zheng<br />
|
||||
Arvin Jou<br />
|
||||
Brandon Sun<br />
|
||||
Chenyang Xu<br />
|
||||
Chunyu Wang<br />
|
||||
Cris Cecka<br />
|
||||
dePaul Miller<br />
|
||||
Edward Cao<br />
|
||||
Fung Xie<br />
|
||||
Guray Ozen<br />
|
||||
Hao Hu<br />
|
||||
Hong Wang<br />
|
||||
Jeremy Furtek<br />
|
||||
Jie Fang <br />
|
||||
JingZe Cui<br />
|
||||
Kihiro Bando<br />
|
||||
Linfeng Zheng<br />
|
||||
Longsheng Du<br />
|
||||
Mina Sun<br />
|
||||
Mindy Li<br />
|
||||
Pradeep Ramani<br />
|
||||
Questa Wang<br />
|
||||
Serif Yesil<br />
|
||||
Tao Xie<br />
|
||||
Tina Li<br />
|
||||
Vicki Wang<br />
|
||||
Vincent Zhang<br />
|
||||
Vijay Thakkar<br />
|
||||
Xiao Dong<br />
|
||||
Xiaolei Shi<br />
|
||||
Xinyu Wang<br />
|
||||
Yihan Chen<br />
|
||||
Yuhan Li<br />
|
||||
Zekun Fan<br />
|
||||
|
||||
*** _Sorted in alphabetical order._
|
||||
|
||||
|
||||
# CuTe Developers
|
||||
|
||||
Cris Cecka<br />
|
||||
Vijay Thakkar<br />
|
||||
@ -100,6 +141,9 @@ David Tanner<br />
|
||||
|
||||
Tri Dao<br />
|
||||
Jay Shah<br />
|
||||
Mehdi Amini<br />
|
||||
Larry Wu<br />
|
||||
Justin Holewinski<br />
|
||||
Timothy Costa<br />
|
||||
Julien Demouth<br />
|
||||
Brian Fahs<br />
|
||||
@ -108,14 +152,11 @@ Michael Goldfarb<br />
|
||||
Mostafa Hagog<br />
|
||||
Fei Hu<br />
|
||||
Alan Kaatz<br />
|
||||
Tina Li<br />
|
||||
Wei Liu<br />
|
||||
Tim Martin<br />
|
||||
Kevin Siu<br />
|
||||
Markus Tavenrath<br />
|
||||
John Tran<br />
|
||||
Vicki Wang<br />
|
||||
Fung Xie<br />
|
||||
Yang Xu<br />
|
||||
Scott Yokim<br />
|
||||
Girish Bharambe<br />
|
||||
|
||||
@ -57,7 +57,7 @@ if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])")
|
||||
elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang")
|
||||
set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation")
|
||||
else()
|
||||
message(FATAL_ERROR "Uknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.")
|
||||
message(FATAL_ERROR "Unknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.")
|
||||
endif()
|
||||
|
||||
if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30")
|
||||
|
||||
188
EULA.txt
Normal file
188
EULA.txt
Normal file
@ -0,0 +1,188 @@
|
||||
NVIDIA Software License Agreement
|
||||
|
||||
IMPORTANT NOTICE – PLEASE READ AND AGREE BEFORE USING THE SOFTWARE
|
||||
This software license agreement (“Agreement”) is a legal agreement between you, whether an individual or entity, (“you”) and NVIDIA Corporation (“NVIDIA”) and governs the use of the NVIDIA CUTLASS DSLs software and materials that NVIDIA delivers to you under this Agreement (“Software”).
|
||||
NVIDIA and you are each a “party” and collectively the “parties.”
|
||||
This Agreement can be accepted only by an adult of legal age of majority in the country in which the Software is used.
|
||||
If you don’t have the required age or authority to accept this Agreement, or if you don’t accept all the terms and conditions of this Agreement, do not use the Software.
|
||||
|
||||
1. License Grants
|
||||
|
||||
1.1. License Grant to You. The Software made available by NVIDIA to you is licensed, not sold.
|
||||
Subject to the terms of this Agreement, NVIDIA grants you a limited, non-exclusive, revocable, non-transferable, and non-sublicensable (except as expressly granted in this Agreement), license to:
|
||||
|
||||
a. install and use copies of the Software,
|
||||
b. configure the Software using configuration files provided (if applicable),
|
||||
c. modify and create derivative works of any sample or example source code NVIDIA delivers to you as part of the Software (“Derivatives”) (if applicable), and
|
||||
d. distribute python files in the Software package in source format as incorporated into a software application subject to the following distribution requirements:
|
||||
|
||||
i. Your application must have material additional functionality, beyond the included portions of the Software.
|
||||
ii. The distributable portions of the Software shall only be accessed by your application.
|
||||
iii. The following notice shall be included in modifications and derivative works of sample source code distributed: “This software contains source code provided by NVIDIA Corporation.”
|
||||
iv. Unless a developer tool is identified in this Agreement as distributable, it is delivered for your internal use only.
|
||||
v. The terms under which you distribute your application must be consistent with the terms of this Agreement, including (without limitation) terms relating to the license grant and license restrictions and protection of NVIDIA’s intellectual property rights.
|
||||
vi. Additionally, you agree that you will protect the privacy, security and legal rights of your application users.
|
||||
|
||||
The foregoing (a) through (d) are, collectively, the “Purpose”, and the developed applications are only for use in systems with NVIDIA GPUs.
|
||||
|
||||
1.2. License Grant to NVIDIA. Subject to the terms of this Agreement, you grant NVIDIA and its affiliates a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit at NVIDIA’s discretion any Derivatives created by or for you.
|
||||
You may, but are not required to, deliver any Derivatives to NVIDIA.
|
||||
|
||||
2. License Restrictions
|
||||
|
||||
Your license to use the Software and Derivatives is restricted as stated in this Section 2 (“License Restrictions”).
|
||||
You will cooperate with NVIDIA and, upon NVIDIA’s written request, you will confirm in writing and provide reasonably requested information to verify your compliance with the terms of this Agreement.
|
||||
You may not:
|
||||
|
||||
2.1. Use the Software or Derivatives for any purpose other than the Purpose;
|
||||
|
||||
2.2. Sell, rent, sublicense, transfer, distribute or otherwise make available to others (except authorized users as stated in Section 3 (“Authorized Users”)) any portion of the Software or Derivatives, except as expressly granted in Section 1.1 (“License Grant to You”);
|
||||
|
||||
2.3. Reverse engineer, decompile, or disassemble the Software components provided in binary form, nor attempt in any other manner to obtain source code of such Software;
|
||||
|
||||
2.4. Modify or create derivative works of the Software, except as expressly granted in Section 1.1 (“License Grant to You”);
|
||||
|
||||
2.5. Change or remove copyright or other proprietary notices in the Software;
|
||||
|
||||
2.6. Bypass, disable, or circumvent any technical limitation, encryption, security, digital rights management or authentication mechanism in the Software;
|
||||
|
||||
2.7. Use the Software or Derivatives in any manner that would cause them to become subject to an open source software license, subject to the terms in Section 6 (“Components Under Other Licenses”);
|
||||
|
||||
2.8. Use the Software or Derivatives in violation of any applicable law or regulation in relevant jurisdictions
|
||||
|
||||
2.9. Indicate that a product or service developed with the Software or Derivatives is sponsored or endorsed by NVIDIA;
|
||||
|
||||
2.10. Replace any NVIDIA software components in the Software that are governed by this Agreement with other software that implements NVIDIA APIs;
|
||||
|
||||
2.11. Reverse engineer, decompile or disassemble any portion of the output generated using Software elements for the purpose of translating such output artifacts to target a non-NVIDIA platform; or
|
||||
|
||||
3. Authorized Users
|
||||
|
||||
You may allow employees and contractors of your entity or of your subsidiary(ies), and for educational institutions also enrolled students, to internally access and use the Software as authorized by this Agreement from your secure network to perform the work authorized by this Agreement on your behalf.
|
||||
You are responsible for the compliance with the terms of this Agreement by your authorized users.
|
||||
Any act or omission that if committed by you would constitute a breach of this Agreement will be deemed to constitute a breach of this Agreement if committed by your authorized users.
|
||||
|
||||
4. Pre-Release
|
||||
|
||||
Software versions identified as alpha, beta, preview, early access or otherwise as pre-release (“Pre-Release”) may not be fully functional, may contain errors or design flaws, and may have reduced or different security, privacy, availability and reliability standards relative to NVIDIA commercial offerings.
|
||||
You use Pre-Release Software at your own risk. NVIDIA did not design or test the Software for use in production or business-critical systems.
|
||||
NVIDIA may choose not to make available a commercial version of Pre-Release Software.
|
||||
NVIDIA may also choose to abandon development and terminate the availability of Pre-Release Software at any time without liability.
|
||||
|
||||
5. Updates
|
||||
|
||||
NVIDIA may at any time and at its option, change, discontinue, or deprecate any part, or all, of the Software, or change or remove features or functionality, or make available patches, workarounds or other updates to the Software.
|
||||
Unless the updates are provided with their separate governing terms, they are deemed part of the Software licensed to you under this Agreement, and your continued use of the Software is deemed acceptance of such changes.
|
||||
|
||||
6. Components Under Other Licenses
|
||||
|
||||
The Software may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms (“Other Licenses”).
|
||||
The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights;
|
||||
except that this Agreement will prevail regarding the use of third-party open source software, unless a third-party open source software license requires its license terms to prevail.
|
||||
Open source software license means any software, data or documentation subject to any license identified as an open source license by the Open Source Initiative (http://opensource.org), Free Software Foundation (http://www.fsf.org) or other similar open source organization or listed by the Software Package Data Exchange (SPDX) Workgroup under the Linux Foundation (http://www.spdx.org).
|
||||
|
||||
7. Ownership
|
||||
|
||||
7.1. NVIDIA Ownership. The Software, including all intellectual property rights, is and will remain the sole and exclusive property of NVIDIA or its licensors.
|
||||
Except as expressly granted in this Agreement, (a) NVIDIA reserves all rights, interests and remedies in connection with the Software, and (b) no other license or right is granted to you by implication, estoppel or otherwise.
|
||||
|
||||
7.2. Your Ownership. Subject to the rights of NVIDIA and its suppliers in the Software, which continue to be licensed as stated in this Agreement, even when incorporated in your products or services, and the extent permitted by applicable law, as between you and NVIDIA, you hold all rights, title and interest in and to your products, services and Derivatives you develop as permitted in this Agreement including their respective intellectual property rights.
|
||||
|
||||
8. Feedback
|
||||
|
||||
You may, but you are not obligated to, provide suggestions, requests, fixes, modifications, enhancements, or other feedback regarding the Software (collectively, “Feedback”).
|
||||
Feedback, even if designated as confidential by you, will not create any confidentiality obligation for NVIDIA or its affiliates.
|
||||
If you provide Feedback, you grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit the Feedback at NVIDIA’s discretion.
|
||||
|
||||
9. Termination
|
||||
|
||||
9.1. Termination. This Agreement will automatically terminate without notice from NVIDIA if you fail to comply with any of the terms in this Agreement or if you commence or participate in any legal proceeding against NVIDIA with respect to the Software.
|
||||
Additionally, either party may terminate this Agreement at any time with thirty (30) days’ advance written notice to the other party.
|
||||
|
||||
9.2. Effect of Termination. Upon any expiration or termination of this Agreement, you will promptly (a) stop using and return, delete or destroy NVIDIA confidential information and all Software received under this Agreement, and (b) delete or destroy Derivatives created under this Agreement, unless an authorized NVIDIA representative provides prior written approval that you may keep a copy of the Derivatives solely for archival purposes.
|
||||
Upon written request, you will certify in writing that you have complied with your obligations under this Section 9.2 (“Effect of Termination”).
|
||||
|
||||
9.3. Survival. Section 1.2 (“License Grant to NVIDIA”), Section 5 (“Updates”), Section 6 (“Components Under Other Licenses”), Section 7 (“Ownership”), Section 8 (“Feedback), Section 9.2 (“Effect of Termination”), Section 9.3 (“Survival”), Section 10 (“Disclaimer of Warranties”), Section 11 (“Limitation of Liability”), Section 12 (“Use in Mission Critical Applications”), Section 13 (“Governing Law and Jurisdiction”), Section 14 (“Indemnity”) and Section 15 (“General”) will survive any expiration or termination of this Agreement.
|
||||
|
||||
10. Disclaimer of Warranties
|
||||
|
||||
THE SOFTWARE IS PROVIDED BY NVIDIA AS-IS AND WITH ALL FAULTS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA DISCLAIMS ALL WARRANTIES AND REPRESENTATIONS OF ANY KIND, WHETHER
|
||||
EXPRESS, IMPLIED OR STATUTORY, RELATING TO OR ARISING UNDER THIS AGREEMENT, INCLUDING, WITHOUT LIMITATION, THE WARRANTIES OF TITLE, NONINFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, USAGE OF TRADE AND COURSE OF DEALING. NVIDIA DOES NOT WARRANT OR ASSUME RESPONSIBILITY FOR THE ACCURACY OR COMPLETENESS OF ANY THIRD-PARTY INFORMATION, TEXT, GRAPHICS, LINKS CONTAINED IN THE SOFTWARE.
|
||||
WITHOUT LIMITING THE FOREGOING, NVIDIA DOES NOT WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS, ANY DEFECTS OR ERRORS WILL BE CORRECTED, ANY CERTAIN CONTENT WILL BE AVAILABLE; OR THAT THE SOFTWARE IS FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS. NO INFORMATION OR ADVICE GIVEN BY NVIDIA WILL IN ANY WAY INCREASE THE SCOPE OF ANY WARRANTY EXPRESSLY PROVIDED IN THIS AGREEMENT.
|
||||
NVIDIA does not warrant or assume responsibility for the accuracy or completeness of any third-party information, text, graphics or links contained in the Software.
|
||||
|
||||
11. Limitations of Liability
|
||||
|
||||
11.1. EXCLUSIONS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL NVIDIA BE LIABLE FOR ANY (I) INDIRECT, PUNITIVE, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES, OR (ii) DAMAGES FOR (a) THE COST OF PROCURING SUBSTITUTE GOODS, OR (b) LOSS OF PROFITS, REVENUES, USE, DATA OR GOODWILL ARISING OUT OF OR RELATED TO THIS AGREEMENT, WHETHER BASED ON BREACH OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY, OR OTHERWISE, AND EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES AND EVEN IF A PARTY’S REMEDIES FAIL THEIR ESSENTIAL PURPOSE.
|
||||
|
||||
11.2. DAMAGES CAP. ADDITIONALLY, TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA’S TOTAL CUMULATIVE AGGREGATE LIABILITY FOR ANY AND ALL LIABILITIES, OBLIGATIONS OR CLAIMS ARISING OUT OF OR RELATED TO THIS AGREEMENT WILL NOT EXCEED FIVE U.S. DOLLARS (US$5).
|
||||
|
||||
12. Use in Mission Critical Applications
|
||||
|
||||
You acknowledge that the Software provided under this Agreement is not designed or tested by NVIDIA for use in any system or application where the use or failure of such system or application developed with NVIDIA’s Software could result in injury, death or catastrophic damage (each, a “Mission Critical Application”).
|
||||
Examples of Mission Critical Applications include use in avionics, navigation, autonomous vehicle applications, AI solutions for automotive products, military, medical, life support or other mission-critical or life-critical applications.
|
||||
NVIDIA will not be liable to you or any third party, in whole or in part, for any claims or damages arising from these uses.
|
||||
You are solely responsible for ensuring that systems and applications developed with the Software include sufficient safety and redundancy features and comply with all applicable legal and regulatory standards and requirements.
|
||||
|
||||
13. Governing Law and Jurisdiction
|
||||
|
||||
This Agreement will be governed in all respects by the laws of the United States and the laws of the State of Delaware, without regard to conflict of laws principles or the United Nations Convention on Contracts for the International Sale of Goods.
|
||||
The state and federal courts residing in Santa Clara County, California will have exclusive jurisdiction over any dispute or claim arising out of or related to this Agreement, and the parties irrevocably consent to personal jurisdiction and venue in those courts;
|
||||
except that either party may apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction.
|
||||
|
||||
14. Indemnity
|
||||
|
||||
By using the Software you agree to defend, indemnify and hold harmless NVIDIA and its affiliates and their respective officers, directors, employees and agents from and against any claims, disputes, demands, liabilities, damages, losses, costs and expenses arising out of or in any way connected with (i) products or services that have been developed or deployed with or use the Software, or claims that they violate laws, or infringe, violate, or misappropriate any third party right;
|
||||
or (ii) use of the Software in breach of the terms of this Agreement.
|
||||
|
||||
15. General
|
||||
|
||||
15.1. Independent Contractors.
|
||||
The parties are independent contractors, and this Agreement does not create a joint venture, partnership, agency, or other form of business association between the parties.
|
||||
Neither party will have the power to bind the other party or incur any obligation on its behalf without the other party’s prior written consent.
|
||||
Nothing in this Agreement prevents either party from participating in similar arrangements with third parties.
|
||||
|
||||
15.2. No Assignment.
|
||||
NVIDIA may assign, delegate or transfer its rights or obligations under this Agreement by any means or operation of law.
|
||||
You may not, without NVIDIA’s prior written consent, assign, delegate or transfer any of your rights or obligations under this Agreement by any means or operation of law, and any attempt to do so is null and void.
|
||||
|
||||
15.3. No Waiver.
|
||||
No failure or delay by a party to enforce any term or obligation of this Agreement will operate as a waiver by that party, or prevent the enforcement of such term or obligation later.
|
||||
|
||||
15.4. Trade Compliance.
|
||||
You agree to comply with all applicable export, import, trade and economic sanctions laws and regulations, as amended, including without limitation U.S. Export Administration Regulations and Office of Foreign Assets Control regulations.
|
||||
You confirm (a) your understanding that export or reexport of certain NVIDIA products or technologies may require a license or other approval from appropriate authorities and (b) that you will not export or reexport any products or technology, directly or indirectly, without first obtaining any required license or other approval from appropriate authorities, (i) to any countries that are subject to any U.S. or local export restrictions (currently including, but not necessarily limited to, Belarus, Cuba, Iran, North Korea, Russia, Syria, the Region of Crimea, Donetsk People’s Republic Region and Luhansk People’s Republic Region);
|
||||
(ii) to any end-user who you know or have reason to know will utilize them in the design, development or production of nuclear, chemical or biological weapons, missiles, rocket systems, unmanned air vehicles capable of a maximum range of at least 300 kilometers, regardless of payload, or intended for military end-use, or any weapons of mass destruction;
|
||||
(iii) to any end-user who has been prohibited from participating in the U.S. or local export transactions by any governing authority;
|
||||
or (iv) to any known military or military-intelligence end-user or for any known military or military-intelligence end-use in accordance with U.S. trade compliance laws and regulations.
|
||||
|
||||
15.5. Government Rights.
|
||||
The Software, documentation and technology (“Protected Items”) are “Commercial products” as this term is defined at 48 C.F.R.
|
||||
2.101, consisting of “commercial computer software” and “commercial computer software documentation” as such terms are used in, respectively, 48 C.F.R.
|
||||
12.212 and 48 C.F.R. 227.7202 & 252.227-7014(a)(1). Before any Protected Items are supplied to the U.S. Government, you will (i) inform the U.S. Government in writing that the Protected Items are and must be treated as commercial computer software and commercial computer software documentation developed at private expense;
|
||||
(ii) inform the U.S. Government that the Protected Items are provided subject to the terms of the Agreement;
|
||||
and (iii) mark the Protected Items as commercial computer software and commercial computer software documentation developed at private expense.
|
||||
In no event will you permit the U.S. Government to acquire rights in Protected Items beyond those specified in 48 C.F.R.
|
||||
52.227-19(b)(1)-(2) or 252.227-7013(c) except as expressly approved by NVIDIA in writing.
|
||||
|
||||
15.6. Notices.
|
||||
Please direct your legal notices or other correspondence to legalnotices@nvidia.com with a copy mailed to NVIDIA Corporation, 2788 San Tomas Expressway, Santa Clara, California 95051, United States of America, Attention: Legal Department.
|
||||
If NVIDIA needs to contact you, you consent to receive the notices by email and agree that such notices will satisfy any legal communication requirements.
|
||||
|
||||
15.7. Severability.
|
||||
If a court of competent jurisdiction rules that a provision of this Agreement is unenforceable, that provision will be deemed modified to the extent necessary to make it enforceable and the remainder of this Agreement will continue in full force and effect.
|
||||
|
||||
15.8. Amendment.
|
||||
Any amendment to this Agreement must be in writing and signed by authorized representatives of both parties.
|
||||
|
||||
15.9. Construction.
|
||||
The headings in the Agreement are included solely for convenience and are not intended to affect the meaning or interpretation of the Agreement.
|
||||
As required by the context of the Agreement, the singular of a term includes the plural and vice versa.
|
||||
|
||||
15.10. Force Majeure.
|
||||
Neither party will be liable during any period where an event or circumstance prevents or delays that party from performing its obligations under this Agreement and that event or circumstance: (i) is not within the reasonable control of that party and is not the result of that party’s negligence, and (ii) cannot be overcome or avoided by that party using reasonably diligent efforts.
|
||||
|
||||
15.11. Entire Agreement.
|
||||
Regarding the subject matter of this Agreement, the parties agree that (a) this Agreement constitutes the entire and exclusive agreement between the parties and supersedes all prior and contemporaneous communications and (b) any additional or different terms or conditions, whether contained in purchase orders, order acknowledgments, invoices or otherwise, will not be binding and are null and void.
|
||||
|
||||
(v. May 8, 2025)
|
||||
@ -25,3 +25,10 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
Certain files within this repository are subject to separate licensing terms:
|
||||
|
||||
- The files located in the `python/CuTeDSL` directory are licensed under the
|
||||
NVIDIA End User License Agreement (EULA). Please refer to
|
||||
https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
for the full terms.
|
||||
|
||||
@ -38,9 +38,9 @@
|
||||
|
||||
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
|
||||
|
||||
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023.
|
||||
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, February 2023.
|
||||
|
||||
- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, Feburary 2023.
|
||||
- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, February 2023.
|
||||
|
||||
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
|
||||
|
||||
|
||||
258
README.md
258
README.md
@ -1,109 +1,129 @@
|
||||

|
||||
# Overview
|
||||
|
||||
# CUTLASS 3.9.2
|
||||
# CUTLASS 4.0.0
|
||||
|
||||
_CUTLASS 3.9.2 - May 2025_
|
||||
_CUTLASS 4.0.0 - May 2025_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
|
||||
and scales within CUDA. It incorporates strategies for hierarchical decomposition and
|
||||
data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes
|
||||
these "moving parts" into reusable, modular software components abstracted by C++ template
|
||||
classes. Primitives for different levels of a conceptual parallelization hierarchy
|
||||
can be specialized and tuned via custom tiling sizes, data types,
|
||||
and other algorithmic policy. The resulting flexibility simplifies their use
|
||||
as building blocks within custom kernels and applications.
|
||||
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
|
||||
and related computations at all levels and scales within CUDA. It incorporates strategies for
|
||||
hierarchical decomposition and data movement. CUTLASS decomposes these "moving parts" into reusable, modular
|
||||
software components and abstractions.
|
||||
|
||||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
Primitives for different levels of a conceptual parallelization hierarchy can be specialized and tuned
|
||||
via custom tiling sizes, data types, and other algorithmic policy. The resulting flexibility simplifies
|
||||
their use as building blocks within custom kernels and applications.
|
||||
|
||||
CUTLASS has been providing CUDA C++ template abstractions for high-performance linear algebra since 2017 and
|
||||
these abstractions provide extensive support for a wide range of computations including
|
||||
mixed-precision computations, specialized data-movement (async copy) and
|
||||
multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16,
|
||||
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
[FP32 emulation via tensor core instruction](https://github.com/NVIDIA/cutlass/tree/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
8b floating point types (e5m2 and e4m3),
|
||||
block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8),
|
||||
narrow integer types (4 and 8b signed and unsigned integers),
|
||||
and binary 1b data types (where architectures allow for the
|
||||
native support of such data types).
|
||||
CUTLASS demonstrates optimal matrix multiply operations
|
||||
native support of such data types) across NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
|
||||
|
||||
To this rich ecosystem of C++ based kernel programming abstractions, CUTLASS 4 adds CUTLASS DSLs. These are Python native interfaces for writing high-performance CUDA kernels based on core CUTLASS and CuTe concepts without any performance compromises. This allows for a much smoother learning curve, orders of magnitude faster compile times, native integration with DL frameworks without writing glue code, and much more intuitive metaprogramming that does not require deep C++ expertise.
|
||||
|
||||
Overall we envision CUTLASS DSLs as a family of domain-specific languages (DSLs). With the release of 4.0, we are releasing the first of these in CuTe DSL. This is a low level programming model that is fully consistent with CuTe C++ abstractions — exposing core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy.
|
||||
|
||||
CuTe DSL demonstrates optimal matrix multiply and other linear algebra operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
|
||||
NVIDIA's Ampere, Hopper, and Blackwell architectures.
|
||||
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via
|
||||
the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution
|
||||
operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline.
|
||||
This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
We believe it will become an indispensable tool for students, researchers, and performance
|
||||
engineers alike — flattening the learning curve of GPU programming, rapidly prototyping kernel
|
||||
designs, and bringing optimized solutions into production.
|
||||
|
||||
See the [Quick Start Guide](./media/docs/cpp/quickstart.md) to get started quickly.
|
||||
CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
|
||||
|
||||
See the [functionality docs](./media/docs/cpp/functionality.md) for a more comprehensive
|
||||
list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU
|
||||
architecture.
|
||||
To get started quickly - please refer :
|
||||
- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
|
||||
- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).
|
||||
|
||||
# What's New in CUTLASS 3.9
|
||||
# What's New in CUTLASS 4.0
|
||||
|
||||
* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
|
||||
- Collective mainloops that target for:
|
||||
* [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
|
||||
* [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
|
||||
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
|
||||
- [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
|
||||
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
|
||||
* Support for Blackwell SM100 Sparse kernels:
|
||||
- Collective mainloop that target for
|
||||
* [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
|
||||
* Set of example that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
|
||||
- [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
|
||||
* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
|
||||
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case.
|
||||
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
|
||||
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
|
||||
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
|
||||
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
|
||||
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
|
||||
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
|
||||
- Support for [grouped-wise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler.
|
||||
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
|
||||
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
|
||||
* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler:
|
||||
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
|
||||
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
|
||||
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
|
||||
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
|
||||
* Support `void` as the D element in sm100 kernel epilogues.
|
||||
### CuTe DSL
|
||||
* CuTe DSL, a Python DSL centered around CuTe's abstractions
|
||||
- [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
|
||||
- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
|
||||
- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
|
||||
* [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass)
|
||||
* Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
|
||||
- [Blackwell SM100 persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
|
||||
- [Blackwell SM100 grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
|
||||
- [Blackwell SM100 fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
|
||||
- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
|
||||
- [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
|
||||
- [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
|
||||
- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
|
||||
- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py)
|
||||
* [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
|
||||
* API updates
|
||||
- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``
|
||||
|
||||
Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
### CUTLASS C++
|
||||
* Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
|
||||
- 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note 101a is supported since CUTLASS 3.9
|
||||
* Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
|
||||
- For example:
|
||||
+ `(old) cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
|
||||
+ `(new) cutlass3x_sm90_tensorop_gemm_bf16_bf16_f32_bf16_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma`
|
||||
- If you are using the CUTLASS library kernel names directly (e.g. to compile a subset of the CUTLASS library with `-DCUTLASS_LIBRARY_KERNELS`, filter kernels in the CUTLASS profiler with `--kernels`), please update your uses accordingly, this is a breaking change.
|
||||
* Further improved [Blockwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) and [Groupwise](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) GEMMs on Hopper and Blackwell.
|
||||
- Added non-power-of-two tile sizes.
|
||||
- Improved performance for K-major scale factors.
|
||||
- The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell SM100 versions.
|
||||
* Enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
- Support LSE output in FMHA Forward kernel.
|
||||
- Enhance performance measurement: support of different warmup iterations; buffer rotation to keep L2 cold; separate testing of persistent and non-persistent.
|
||||
- Enhance testing of variable sequence length.
|
||||
- Disable B2B mode in MLA to simplify the sample.
|
||||
- Clarify that `fmha_gen` sample only supports head dim 128.
|
||||
- Fixes for split-kv output in MLA.
|
||||
* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
|
||||
- Enable runtime datatype for Blackwell SM100 grouped GEMM. Profiler support is also added.
|
||||
- Enable kernel parameter exploration for Blackwell SM100 grouped GEMM - raster_order, swizzle.
|
||||
* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
|
||||
* Add dynamic and preferred cluster support for convolution Blackwell SM100 kernels.
|
||||
* Fix profiler issues which cause no output or not supported error for some kernels.
|
||||
* Optimizations for Blackwell SM100 and SM120 block scaled kernels.
|
||||
* Support for Blackwell SM120 blockwise dense gemm in CUTLASS library and profiler.
|
||||
* New [Hopper SM90 FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
* CuTe changes:
|
||||
- Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda and introduces transform-tensors to avoid any extra register or load/store overhead in using bool-tensors.
|
||||
- New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy.
|
||||
- Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
|
||||
- Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
|
||||
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
* Optimal code generation with CUDA toolkit versions 12.9.
|
||||
|
||||
Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.**
|
||||
**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit nearly optimal utilization of peak theoretical throughput. The figure below
|
||||
shows CUTLASS 3.8's performance as a % of theoretical peak utilization
|
||||
shows CUTLASS 3.8's performance as a % of theoretical peak utilization
|
||||
on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg></p>
|
||||

|
||||
|
||||
The two figures below show the continual CUTLASS performance improvements
|
||||
The two figures below show the continual CUTLASS performance improvements
|
||||
on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since
|
||||
CUTLASS 3.1.
|
||||
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
Tensor Core operations are implemented using CUDA's
|
||||
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
Tensor Core operations are implemented using CUDA's
|
||||
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
|
||||
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
|
||||

|
||||

|
||||
|
||||
# CuTe
|
||||
|
||||
@ -125,7 +145,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
|
||||
This greatly simplifies the design and improves code composability and readability.
|
||||
More documentation specific to CuTe can be found in its
|
||||
[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md).
|
||||
[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html).
|
||||
|
||||
# Compatibility
|
||||
|
||||
@ -135,7 +155,7 @@ Minimum requirements:
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions.
|
||||
|
||||
@ -196,19 +216,19 @@ the kernel is expected to fail with a runtime error.
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
```
|
||||
Or
|
||||
Or
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
```
|
||||
|
||||
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
|
||||
products has a different compute capability than the one underpinning
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
|
||||
compiled for Blackwell SM100 architecture with arch conditional features
|
||||
(using `sm100a`) are not compatible with RTX 50 series GPUs.
|
||||
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
|
||||
products has a different compute capability than the one underpinning
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
|
||||
compiled for Blackwell SM100 architecture with arch conditional features
|
||||
(using `sm100a`) are not compatible with RTX 50 series GPUs.
|
||||
|
||||
Please refer to the [functionality documentation](./media/docs/cpp/functionality.md)
|
||||
Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html)
|
||||
for details on which kernels require which target architectures.
|
||||
|
||||
# Documentation
|
||||
@ -216,22 +236,22 @@ for details on which kernels require which target architectures.
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS
|
||||
- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development
|
||||
- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
|
||||
- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
|
||||
- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code
|
||||
- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
|
||||
- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
|
||||
kernels in the same stream, and how it is used in CUTLASS.
|
||||
|
||||
# Resources
|
||||
@ -251,7 +271,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
|
||||
paths.
|
||||
|
||||
CUTLASS unit tests, examples, and utilities can be build with CMake.
|
||||
The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md).
|
||||
The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
@ -291,12 +311,12 @@ All tests should pass on supported platforms, though the exact number of tests m
|
||||
|
||||
# Project Structure
|
||||
|
||||
CUTLASS is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests.
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass) provides a complete list of files, classes,
|
||||
CUTLASS is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests.
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass) provides a complete list of files, classes,
|
||||
and template concepts defined in the CUTLASS project.
|
||||
|
||||
A detailed explanation of the source code organization may be found in the
|
||||
[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below.
|
||||
A detailed explanation of the source code organization may be found in the
|
||||
[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below.
|
||||
|
||||
## CUTLASS Template Library
|
||||
|
||||
@ -320,7 +340,7 @@ include/ # client applications should target this directory
|
||||
reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model
|
||||
|
||||
thread/ # simt code that can be performed within a CUDA thread
|
||||
|
||||
|
||||
transform/ # code specialized for layout, type, and domain transformations
|
||||
|
||||
* # core vocabulary types, containers, and basic numeric operations
|
||||
@ -345,7 +365,7 @@ include/ # client applications should target this directory
|
||||
|
||||
### CUTLASS SDK Examples
|
||||
|
||||
[CUTLASS SDK examples](./examples) apply CUTLASS templates to implement basic computations.
|
||||
[CUTLASS SDK examples](https://github.com/NVIDIA/cutlass/tree/main/examples) apply CUTLASS templates to implement basic computations.
|
||||
|
||||
### Tools
|
||||
|
||||
@ -358,9 +378,9 @@ tools/
|
||||
|
||||
profiler/ # CUTLASS Profiler - command-line utility for executing operations in the
|
||||
# CUTLASS Library
|
||||
|
||||
|
||||
util/ # CUTLASS Utilities - contains numerous helper classes for
|
||||
include/ # manging tensors in device memory, reference
|
||||
include/ # managing tensors in device memory, reference
|
||||
cutlass/ # implementations for GEMM, random initialization
|
||||
util/ # of tensors, and I/O.
|
||||
```
|
||||
@ -370,7 +390,7 @@ tools/
|
||||
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md).
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
@ -384,7 +404,7 @@ $ make cutlass_profiler -j16
|
||||
|
||||
By default, only one tile size is instantiated for each data type, math instruction, and layout.
|
||||
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
|
||||
Beware, this results in *tens of thousands* of kernels and long build times.
|
||||
Beware, this results in *tens of thousands* of kernels and long build times.
|
||||
This would also result in a large binary size and on some platforms linker to fail on building the library.
|
||||
Therefore, it's highly recommended to generate only a subset of kernels as demonstrated in the sub-section below.
|
||||
```bash
|
||||
@ -395,13 +415,13 @@ $ make cutlass_profiler -j16
|
||||
|
||||
## Building a subset of GEMM and Convolution kernels (_reduced_ build times)
|
||||
|
||||
To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
|
||||
To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
|
||||
wildcard characters may be used to reduce the set of kernels. The following examples show building exactly one
|
||||
or a subset of kernels for NVIDIA Ampere and Turing architecture:
|
||||
|
||||
### Building a subset Tensor Core GEMM kernels
|
||||
|
||||
To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targeting NVIDIA Ampere and Turing architecture,
|
||||
To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targeting NVIDIA Ampere and Turing architecture,
|
||||
use the below cmake command line:
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*gemm_f16_*_nt_align8
|
||||
@ -490,7 +510,7 @@ $ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
|
||||
|
||||
### Building a subset of Tensor Core Convolution kernels
|
||||
|
||||
To compile a subset of Tensor core convolution kernels implementing forward propagation (fprop) with FP32 accumulation
|
||||
To compile a subset of Tensor core convolution kernels implementing forward propagation (fprop) with FP32 accumulation
|
||||
and FP16 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line:
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*fprop_optimized_f16
|
||||
@ -538,7 +558,7 @@ reference_device: Passed
|
||||
|
||||
### Building one Convolution CUDA kernel
|
||||
|
||||
To compile and run one CUDA Core convolution kernel implementing forward propagation (fprop) with F32 accumulation
|
||||
To compile and run one CUDA Core convolution kernel implementing forward propagation (fprop) with F32 accumulation
|
||||
and FP32 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line:
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc
|
||||
@ -586,14 +606,14 @@ reference_device: Passed
|
||||
|
||||
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
|
||||
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
|
||||
- [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md)
|
||||
- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html)
|
||||
|
||||
|
||||
# About
|
||||
|
||||
CUTLASS is released by NVIDIA Corporation as Open Source software under the
|
||||
CUTLASS is released by NVIDIA Corporation as Open Source software under the
|
||||
[3-clause "New" BSD license](LICENSE.txt).
|
||||
|
||||
# Contributors
|
||||
|
||||
@ -36,7 +36,7 @@ set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "
|
||||
|
||||
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
function(cutlass_generate_kernel_filter_and_testlists_files)
|
||||
function(cutlass_generate_kernel_filter_and_testlist_files)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs TEST_SET_NAME)
|
||||
@ -59,30 +59,30 @@ function(cutlass_generate_kernel_filter_and_testlists_files)
|
||||
)
|
||||
|
||||
if(NOT cutlass_FILTER_GENERATION_RESULT EQUAL 0)
|
||||
message(FATAL_ERROR "Error generating kernel filters and testlists files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log")
|
||||
message(FATAL_ERROR "Error generating kernel filters and testlist files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
|
||||
|
||||
set(PROFILER_ARCH_LIST 100a 101a 120a)
|
||||
set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f)
|
||||
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
|
||||
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
|
||||
message(FATAL_ERROR "Only SM100a/101a/120a compute capability is supported with profiler-based unit tests")
|
||||
message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 0)
|
||||
|
||||
message(STATUS "Building for L0 profiler-based functional regressions")
|
||||
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l0)
|
||||
cutlass_generate_kernel_filter_and_testlist_files(TEST_SET_NAME kernel_testlist_l0)
|
||||
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
|
||||
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")
|
||||
|
||||
elseif (CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 1)
|
||||
|
||||
message(STATUS "Building for L1 profiler-based functional regressions")
|
||||
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l1)
|
||||
cutlass_generate_kernel_filter_and_testlist_files(TEST_SET_NAME kernel_testlist_l1)
|
||||
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
|
||||
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")
|
||||
|
||||
|
||||
@ -489,7 +489,7 @@ int run(Options &options)
|
||||
std::cout << " Batches : " << options.l << std::endl;
|
||||
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@ -124,7 +124,7 @@ struct CooperativeConfig {
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_256,_128,_128>;
|
||||
using ClusterShape = Shape<_2,_2,_1>;
|
||||
using ClusterShape = Shape<_1,_2,_1>;
|
||||
};
|
||||
|
||||
struct PingpongConfig {
|
||||
@ -296,14 +296,14 @@ struct Options {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = alignment * ((rand() % 64) + 1);
|
||||
if (m < 0) {
|
||||
m = alignment * ((rand() % 64));
|
||||
}
|
||||
if (n < 1) {
|
||||
n = alignment * ((rand() % 64) + 1);
|
||||
if (n < 0) {
|
||||
n = alignment * ((rand() % 64));
|
||||
}
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
if (k < 0) {
|
||||
k = alignment * ((rand() % 64));
|
||||
}
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
@ -333,19 +333,9 @@ struct Options {
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
extent.at(i) = std::atoi(tokens.at(i).c_str());
|
||||
}
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
@ -500,10 +490,27 @@ void initialize(const Options &options) {
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
// If the current group's matrix has size 0, set the pointer to nullptr
|
||||
if (i < options.groups - 1 && offset_A.at(i) == offset_A.at(i + 1)) {
|
||||
ptr_A_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_B.at(i) == offset_B.at(i + 1)) {
|
||||
ptr_B_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_C.at(i) == offset_C.at(i + 1)) {
|
||||
ptr_C_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
}
|
||||
if (i < options.groups - 1 && offset_D.at(i) == offset_D.at(i + 1)) {
|
||||
ptr_D_host.at(i) = nullptr;
|
||||
} else {
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
}
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
@ -539,9 +546,10 @@ void initialize(const Options &options) {
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_A, seed + 2021);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
initialize_block(block_C, seed + 2023);
|
||||
initialize_block(block_D, seed + 2024);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
}
|
||||
@ -653,6 +661,13 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
GemmT gemm;
|
||||
|
||||
@ -700,14 +715,8 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@ -770,9 +770,6 @@ int main(int argc, char const** argv) {
|
||||
|
||||
bool satisfied;
|
||||
if (props.major < 10) {
|
||||
// Pre-Blackwell
|
||||
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4);
|
||||
satisfied &= (props.major > 8) || (props.major == 8 && props.minor == 9);
|
||||
}
|
||||
else {
|
||||
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
|
||||
@ -786,7 +783,6 @@ int main(int argc, char const** argv) {
|
||||
std::cout
|
||||
<< "CUTLASS's FP8 SM89 example requires an NVIDIA GPU with compute capability 8.9 or greater "
|
||||
<< "and CUDA toolkit version 12.4 or later"
|
||||
<< " (12.8 or later needed for SM100+)"
|
||||
<< std::endl;
|
||||
|
||||
return 0;
|
||||
|
||||
@ -42,7 +42,6 @@
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cutlass/arch/grid_dependency_control.h"
|
||||
|
||||
@ -288,7 +287,7 @@ struct CollectiveMma<
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
@ -445,7 +444,7 @@ struct CollectiveMma<
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
launch_dep_grids = true;
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
@ -453,7 +452,7 @@ struct CollectiveMma<
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
}
|
||||
@ -533,7 +532,7 @@ struct CollectiveMma<
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
launch_dep_grids = true;
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
@ -541,7 +540,7 @@ struct CollectiveMma<
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
}
|
||||
@ -634,9 +633,9 @@ struct CollectiveMma<
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
@ -854,7 +853,7 @@ struct CollectiveMma<
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
|
||||
@ -132,8 +132,8 @@ using namespace cute;
|
||||
using TP = _8;
|
||||
static constexpr int TP_ = TP{};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
// Distributed GEMM tiling/sharding schedule
|
||||
// Choices:
|
||||
@ -252,7 +252,8 @@ HostTensorB tensor_B_arr[TP_];
|
||||
HostTensorD tensor_C_arr[TP_];
|
||||
HostTensorD tensor_D_arr[TP_];
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
@ -344,8 +345,8 @@ struct Result {
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -803,17 +804,18 @@ int run(Options &options) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
|
||||
// CUTLASS must be compiled with CUDA Toolkit 12.6 or newer to run this example
|
||||
// and must have compute capability at least 90.
|
||||
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
|
||||
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
|
||||
// Some necessary cuda graph APIs were only introduced in CUDA 12.6.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) {
|
||||
std::cerr << "This example requires CUDA 12.6 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
@ -857,8 +859,12 @@ int main(int argc, char const **args) {
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
|
||||
run(options);
|
||||
#else
|
||||
std::cerr
|
||||
<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.6 or later." << std::endl;
|
||||
return 0;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -205,7 +205,6 @@ cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
uint32_t mma_promotion_interval;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
@ -405,12 +404,6 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
blockscale_tensor_A.sync_device();
|
||||
blockscale_tensor_B.sync_device();
|
||||
|
||||
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
|
||||
// Else kernel will fail can_implement() check
|
||||
// Deprecation Notice : We plan to remove this params member in an upcoming release
|
||||
// Users can safely delete this line from their code, since the default is already 4
|
||||
mma_promotion_interval = 4;
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.resize(c_coord);
|
||||
tensor_aux.sync_device();
|
||||
@ -470,7 +463,6 @@ typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &op
|
||||
stride_A,
|
||||
tensor_B.device_data(),
|
||||
stride_B,
|
||||
mma_promotion_interval,
|
||||
blockscale_tensor_A.device_data(),
|
||||
layout_SFA,
|
||||
blockscale_tensor_B.device_data(),
|
||||
|
||||
@ -215,7 +215,6 @@ cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
uint32_t mma_promotion_interval;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
@ -413,12 +412,6 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
blockscale_tensor_A.sync_device();
|
||||
blockscale_tensor_B.sync_device();
|
||||
|
||||
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
|
||||
// Else kernel will fail can_implement() check
|
||||
// Deprecation Notice : We plan to remove this params member in an upcoming release
|
||||
// Users can safely delete this line from their code, since the default is already 4
|
||||
mma_promotion_interval = 4;
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.resize(c_coord);
|
||||
tensor_aux.sync_device();
|
||||
@ -479,7 +472,6 @@ GemmArguments args_from_options(const Options<RasterOrderOptions> &options)
|
||||
stride_A,
|
||||
tensor_B.device_data(),
|
||||
stride_B,
|
||||
mma_promotion_interval,
|
||||
blockscale_tensor_A.device_data(),
|
||||
layout_SFA,
|
||||
blockscale_tensor_B.device_data(),
|
||||
|
||||
@ -250,8 +250,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
@ -518,7 +516,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
arguments.scheduler.raster_order = options.raster_order;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
@ -690,10 +688,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
@ -747,7 +745,7 @@ int main(int argc, char const **args) {
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions, ProblemShape> options;
|
||||
Options<ProblemShape> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
|
||||
@ -253,8 +253,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
@ -523,7 +521,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
arguments.scheduler.raster_order = options.raster_order;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
@ -699,10 +697,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
if (options.raster_order == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
else if (options.raster_order == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
@ -755,7 +753,7 @@ int main(int argc, char const **args) {
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions, ProblemShape> options;
|
||||
Options<ProblemShape> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
|
||||
@ -30,10 +30,9 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
template<typename _RasterOrderOptions, typename _ProblemShape>
|
||||
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||
template<typename _ProblemShape>
|
||||
struct Options {
|
||||
|
||||
using RasterOrderOptions = _RasterOrderOptions;
|
||||
using ProblemShape = _ProblemShape;
|
||||
|
||||
bool help = false;
|
||||
@ -50,7 +49,7 @@ struct Options {
|
||||
int const m_alignment = 128;
|
||||
int const n_alignment = 128;
|
||||
|
||||
RasterOrderOptions raster;
|
||||
RasterOrderOptions raster_order;
|
||||
int swizzle;
|
||||
|
||||
// Parses the command line
|
||||
@ -74,13 +73,13 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
raster_order = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
raster_order = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
raster_order = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
|
||||
@ -543,7 +543,7 @@ int run(Options &options) {
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit or newer to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
@ -560,7 +560,6 @@ int main(int argc, char const **args) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -237,7 +237,7 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
@ -354,19 +354,9 @@ struct Options {
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
extent.at(i) = std::atoi(tokens.at(i).c_str());
|
||||
}
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
@ -745,7 +735,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@ -124,6 +124,7 @@ constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // A
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
|
||||
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
|
||||
|
||||
using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands
|
||||
constexpr int OutputSFVectorSize = 16;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
|
||||
@ -299,7 +300,7 @@ auto make_iterator(T* ptr) {
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
@ -422,19 +423,9 @@ struct Options {
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
extent.at(i) = std::atoi(tokens.at(i).c_str());
|
||||
}
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
@ -885,7 +876,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@ -490,7 +490,7 @@ int run(Options &options)
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -490,7 +490,7 @@ int run(Options &options)
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -499,11 +499,11 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -116,16 +116,20 @@ struct Options {
|
||||
int h_k = 1;
|
||||
int q = 256;
|
||||
int k = 256;
|
||||
std::vector<int> varlen_q;
|
||||
std::vector<int> varlen_k;
|
||||
int d = 128;
|
||||
int warmup_iterations = 1;
|
||||
int iterations = 3;
|
||||
int tensor_ring_buffers = 1;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
bool residual = false;
|
||||
bool varlen = false;
|
||||
bool persistent = false;
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
@ -179,20 +183,87 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
|
||||
std::string varlen_q_str;
|
||||
cmd.get_cmd_line_argument("varlen-q", varlen_q_str);
|
||||
std::string varlen_k_str;
|
||||
cmd.get_cmd_line_argument("varlen-k", varlen_k_str);
|
||||
|
||||
if (varlen && ! varlen_q_str.empty()) {
|
||||
varlen_q.clear();
|
||||
while (! varlen_q_str.empty()) {
|
||||
size_t pos = varlen_q_str.find(':');
|
||||
varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos)));
|
||||
if (pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
varlen_q_str = varlen_q_str.substr(pos + 1);
|
||||
}
|
||||
if (b == -1) {
|
||||
b = static_cast<int>(varlen_q.size());
|
||||
}
|
||||
if (b != static_cast<int>(varlen_q.size())) {
|
||||
std::cout << "Error: Invalid --varlen-q length\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
int new_q = 0;
|
||||
for (auto elem : varlen_q) {
|
||||
new_q += elem;
|
||||
}
|
||||
if (q != -1) {
|
||||
std::cout << "Error: Can't provide --q and --varlen-q\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
q = new_q;
|
||||
}
|
||||
|
||||
if (varlen && ! varlen_k_str.empty()) {
|
||||
varlen_k.clear();
|
||||
while (! varlen_k_str.empty()) {
|
||||
size_t pos = varlen_k_str.find(':');
|
||||
varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos)));
|
||||
if (pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
varlen_k_str = varlen_k_str.substr(pos + 1);
|
||||
}
|
||||
if (b == -1) {
|
||||
b = static_cast<int>(varlen_k.size());
|
||||
}
|
||||
if (b != static_cast<int>(varlen_k.size())) {
|
||||
std::cout << " Error: Invalid --varlen-k length\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
int new_k = 0;
|
||||
for (auto elem : varlen_k) {
|
||||
new_k += elem;
|
||||
}
|
||||
if (k != -1) {
|
||||
std::cout << "Error: Can't provide --k and --varlen-k\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
k = new_k;
|
||||
}
|
||||
|
||||
if (q == -1) q = k;
|
||||
if (k == -1) k = q;
|
||||
if (q == -1 && k == -1) q = k = defaults.q;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);
|
||||
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
persistent = cmd.check_cmd_line_flag("persistent");
|
||||
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
if (mask == "no" || mask == "") {
|
||||
@ -210,7 +281,6 @@ struct Options {
|
||||
causal = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
|
||||
@ -234,11 +304,16 @@ struct Options {
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --varlen-q=<int>:<int...> Sets the variable Q extent per batch (colon separated)\n"
|
||||
<< " --varlen-k=<int>:<int...> Sets the variable K extent per batch (colon separated)\n"
|
||||
<< " --d=<int> Sets the D extent\n"
|
||||
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
|
||||
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||
<< " --persistent Enables persistent scheduler\n"
|
||||
<< " --varlen Enables variable sequence length\n"
|
||||
<< " B*Q and B*K become the total sequence length\n"
|
||||
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
||||
@ -379,40 +454,55 @@ struct FwdRunner {
|
||||
StrideLSE stride_LSE;
|
||||
uint64_t seed = 0;
|
||||
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_K;
|
||||
DeviceAllocation<Element> block_V;
|
||||
DeviceAllocation<ElementOut> block_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_LSE;
|
||||
DeviceAllocation<ElementOut> block_ref_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
|
||||
struct DeviceBuffer {
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_K;
|
||||
DeviceAllocation<Element> block_V;
|
||||
DeviceAllocation<ElementOut> block_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_LSE;
|
||||
DeviceAllocation<ElementOut> block_ref_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_q;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_kv;
|
||||
|
||||
DeviceBuffer() = default;
|
||||
DeviceBuffer(const DeviceBuffer&) = delete;
|
||||
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
|
||||
|
||||
size_t get_storage_size() const {
|
||||
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
|
||||
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
|
||||
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
|
||||
+ device_cumulative_seqlen_kv.get_storage_size();
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::unique_ptr<DeviceBuffer>> buffers;
|
||||
|
||||
std::vector<int> cumulative_seqlen_q;
|
||||
std::vector<int> cumulative_seqlen_kv;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_q;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_kv;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
bool verify(const ProblemShapeType& problem_shape) {
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
|
||||
Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
@ -431,7 +521,7 @@ struct FwdRunner {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
@ -439,20 +529,22 @@ struct FwdRunner {
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
bool passed_LSE = true; // future work
|
||||
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
// if ( ! passed_LSE) {
|
||||
// std::cerr << "failed LSE: max diff " << max_diff
|
||||
// << " mean " << mean_diff << std::endl;
|
||||
// }
|
||||
bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_LSE) {
|
||||
std::cerr << "failed LSE: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) {
|
||||
auto initialize_varlen(
|
||||
const Options& options, const ProblemShape& problem_size,
|
||||
const bool kVarlenSame = true) {
|
||||
|
||||
int num_batches = get<3,1>(problem_size);
|
||||
|
||||
// generate Q as --b times
|
||||
@ -480,8 +572,12 @@ struct FwdRunner {
|
||||
int max_seqlen_kv = 0;
|
||||
|
||||
for (int i = 0; i < num_batches; i++) {
|
||||
int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng);
|
||||
int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng);
|
||||
int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) :
|
||||
kVarlenSame ? get<0>(problem_size) :
|
||||
generate_positive_int(dist_q, rng);
|
||||
int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) :
|
||||
kVarlenSame ? get<1>(problem_size) :
|
||||
generate_positive_int(dist_kv, rng);
|
||||
|
||||
total_seqlen_q += seqlen_q;
|
||||
total_seqlen_kv += seqlen_kv;
|
||||
@ -522,7 +618,7 @@ struct FwdRunner {
|
||||
decltype(problem_shape_in) problem_size;
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
|
||||
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(options, problem_shape_in);
|
||||
problem_shape = problem_shape_launch;
|
||||
problem_size = problem_shape_init;
|
||||
}
|
||||
@ -559,50 +655,72 @@ struct FwdRunner {
|
||||
get<1,1>(stride_LSE) = 0;
|
||||
}
|
||||
|
||||
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
block_LSE.reset(size(shape_LSE));
|
||||
block_ref_O.reset(size(shape_QO));
|
||||
block_ref_LSE.reset(size(shape_LSE));
|
||||
auto buffer_init_fn = [&](auto& buffer) {
|
||||
buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_LSE.reset(size(shape_LSE));
|
||||
buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
buffer.block_ref_LSE.reset(size(shape_LSE));
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||
initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
|
||||
initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
|
||||
|
||||
if ( ! cumulative_seqlen_q.empty()) {
|
||||
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
||||
device_cumulative_seqlen_q.copy_from_host(
|
||||
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
||||
}
|
||||
if ( ! cumulative_seqlen_kv.empty()) {
|
||||
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
||||
device_cumulative_seqlen_kv.copy_from_host(
|
||||
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
||||
if ( ! cumulative_seqlen_q.empty()) {
|
||||
buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
||||
buffer.device_cumulative_seqlen_q.copy_from_host(
|
||||
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
||||
}
|
||||
if ( ! cumulative_seqlen_kv.empty()) {
|
||||
buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
||||
buffer.device_cumulative_seqlen_kv.copy_from_host(
|
||||
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
||||
}
|
||||
};
|
||||
|
||||
buffers.push_back(std::make_unique<DeviceBuffer>());
|
||||
buffer_init_fn(*buffers.back());
|
||||
|
||||
int tensor_ring_buffers = options.tensor_ring_buffers;
|
||||
for (int i = 1; i < tensor_ring_buffers; i++) {
|
||||
buffers.push_back(std::make_unique<DeviceBuffer>());
|
||||
buffer_init_fn(*buffers.back());
|
||||
}
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
|
||||
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
|
||||
get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
|
||||
get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
|
||||
}
|
||||
|
||||
return problem_shape;
|
||||
}
|
||||
|
||||
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
|
||||
auto problem_shape_ = problem_shape;
|
||||
if constexpr (kIsVarlen) {
|
||||
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
|
||||
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
|
||||
}
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape_,
|
||||
{ buffers[buffer_index]->block_Q.get(), stride_Q,
|
||||
buffers[buffer_index]->block_K.get(), stride_K,
|
||||
buffers[buffer_index]->block_V.get(), stride_V },
|
||||
{ buffers[buffer_index]->block_O.get(), stride_O,
|
||||
buffers[buffer_index]->block_LSE.get(), stride_LSE },
|
||||
hw_info
|
||||
};
|
||||
return arguments;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
|
||||
ProblemShapeType problem_shape = initialize(options);
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
{ block_Q.get(), stride_Q,
|
||||
block_K.get(), stride_K,
|
||||
block_V.get(), stride_V },
|
||||
{ block_O.get(), stride_O,
|
||||
block_LSE.get(), stride_LSE },
|
||||
hw_info
|
||||
};
|
||||
int buffer_index = 0;
|
||||
typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
|
||||
|
||||
Operation op;
|
||||
|
||||
@ -630,11 +748,21 @@ struct FwdRunner {
|
||||
}
|
||||
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
for (int i = 0; i < options.warmup_iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
buffer_index = (buffer_index + 1) % buffers.size();
|
||||
arguments = get_arguments(problem_shape, hw_info, buffer_index);
|
||||
status = op.update(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
|
||||
<< std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
@ -672,6 +800,14 @@ struct FwdRunner {
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
buffer_index = (buffer_index + 1) % buffers.size();
|
||||
arguments = get_arguments(problem_shape, hw_info, buffer_index);
|
||||
status = op.update(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
|
||||
<< std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -734,10 +870,10 @@ struct FwdRunner {
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
passed = verify(problem_shape, *buffers[0]);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
@ -752,11 +888,18 @@ struct FwdRunner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
if (! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
@ -789,10 +932,14 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
|
||||
|
||||
using HeadDim = _128;
|
||||
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
if (options.persistent) {
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
}
|
||||
else {
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -818,10 +965,14 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
|
||||
|
||||
using HeadDim = _64;
|
||||
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
if (options.persistent) {
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
}
|
||||
else {
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -845,10 +996,14 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
|
||||
using HeadDim = _32;
|
||||
|
||||
#ifdef FP8
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
if (options.persistent) {
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
}
|
||||
else {
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -945,7 +1100,7 @@ int main_single(int argc, char const **args) {
|
||||
});
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -953,8 +1108,6 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -981,7 +1134,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -689,11 +689,18 @@ struct ExampleRunner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
|
||||
if (result.supported && ! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
|
||||
@ -781,12 +788,17 @@ int main_single(int argc, char const **args) {
|
||||
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
|
||||
)
|
||||
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
if (options.d == 128) {
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
}
|
||||
else {
|
||||
std::cout << "Head Dimension != 128 is not supported for the fmha_gen example\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
@ -797,8 +809,6 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -825,7 +835,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -98,6 +98,9 @@ struct Options {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "l") {
|
||||
dst = InitStyle::kRandomLarge;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
@ -203,6 +206,11 @@ void initialize_block(
|
||||
block.get(), block.size(), seed, (Element) -1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandomLarge: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) -1, (Element) 100);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
@ -383,11 +391,7 @@ struct Runner {
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
#ifdef B2B
|
||||
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#else
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#endif
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
@ -396,7 +400,6 @@ struct Runner {
|
||||
}
|
||||
|
||||
bool passed_LSE = true;
|
||||
#ifndef B2B
|
||||
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
@ -404,7 +407,6 @@ struct Runner {
|
||||
std::cerr << "failed LSE: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
@ -670,11 +672,18 @@ struct Runner {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_result = 0;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
if (! result.passed) {
|
||||
main_result = -1;
|
||||
}
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
|
||||
@ -798,8 +807,6 @@ int main_single(int argc, char const **args) {
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
@ -826,7 +833,7 @@ int main(int argc, char const **args) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
return main_result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -43,14 +43,30 @@ set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --v
|
||||
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
|
||||
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
|
||||
|
||||
set(TEST_VARLEN_00 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_01 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_02 --verify --varlen --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128)
|
||||
set(TEST_VARLEN_03 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_04 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_05 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512)
|
||||
set(TEST_VARLEN_06 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512)
|
||||
set(TEST_VARLEN_07 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512)
|
||||
set(TEST_VARLEN_08 --verify --varlen --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512)
|
||||
set(TEST_VARLEN_09 --verify --varlen --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300)
|
||||
set(TEST_VARLEN_10 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5)
|
||||
set(TEST_VARLEN_11 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10)
|
||||
set(TEST_VARLEN_12 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766)
|
||||
set(TEST_VARLEN_13 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766)
|
||||
set(TEST_VARLEN_14 --verify --varlen --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1)
|
||||
|
||||
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
|
||||
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --page=128 --verify)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
|
||||
@ -62,10 +78,25 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
TEST_CAUSAL
|
||||
TEST_VARLEN
|
||||
TEST_HDIM64
|
||||
TEST_GQA
|
||||
TEST_VARLEN_00
|
||||
TEST_VARLEN_01
|
||||
TEST_VARLEN_02
|
||||
TEST_VARLEN_03
|
||||
TEST_VARLEN_04
|
||||
TEST_VARLEN_05
|
||||
TEST_VARLEN_06
|
||||
TEST_VARLEN_07
|
||||
TEST_VARLEN_08
|
||||
TEST_VARLEN_09
|
||||
TEST_VARLEN_10
|
||||
TEST_VARLEN_11
|
||||
TEST_VARLEN_12
|
||||
TEST_VARLEN_13
|
||||
TEST_VARLEN_14
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -75,11 +106,11 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
TEST_GEN_GQA
|
||||
TEST_GEN_REMAP
|
||||
TEST_GEN_CACHEONLY
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -104,26 +135,12 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
|
||||
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_b2b_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
|
||||
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
TEST_VARLEN
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
@ -144,4 +161,19 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
|
||||
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
|
||||
endforeach()
|
||||
|
||||
# Add a target that builds all examples
|
||||
add_custom_target(77_blackwell_fmha_all
|
||||
DEPENDS
|
||||
77_blackwell_fmha_fp8
|
||||
77_blackwell_fmha_fp16
|
||||
77_blackwell_fmha_gen_fp8
|
||||
77_blackwell_fmha_gen_fp16
|
||||
77_blackwell_mla_2sm_fp8
|
||||
77_blackwell_mla_2sm_fp16
|
||||
77_blackwell_mla_2sm_cpasync_fp8
|
||||
77_blackwell_mla_2sm_cpasync_fp16
|
||||
77_blackwell_fmha_bwd_fp8
|
||||
77_blackwell_fmha_bwd_fp16
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -157,7 +157,8 @@ struct CausalMask : NoMask {
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
|
||||
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
|
||||
@ -42,7 +42,7 @@ template<
|
||||
class ElementAcc,
|
||||
class TileShape, // Q, D, _
|
||||
class StrideO, // Q, D, B
|
||||
class StrideLSE // Q, B
|
||||
class StrideLSE_ // Q, B
|
||||
>
|
||||
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
@ -54,6 +54,8 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
|
||||
using SmemLayoutO_ = SmemLayoutO;
|
||||
using StrideLSE = StrideLSE_;
|
||||
using ElementOut = Element;
|
||||
|
||||
struct TensorStorage {
|
||||
|
||||
@ -79,6 +81,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
struct Params {
|
||||
TMA_O tma_store_o;
|
||||
|
||||
ElementAcc* ptr_LSE;
|
||||
StrideLSE dLSE;
|
||||
};
|
||||
|
||||
template<class ProblemShape>
|
||||
@ -110,7 +115,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
);
|
||||
|
||||
return {
|
||||
tma_store_o
|
||||
tma_store_o,
|
||||
args.ptr_LSE,
|
||||
args.dLSE
|
||||
};
|
||||
}
|
||||
|
||||
@ -119,6 +126,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
|
||||
}
|
||||
|
||||
const Params& params;
|
||||
|
||||
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
store(
|
||||
|
||||
@ -505,12 +505,12 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ...
|
||||
}
|
||||
|
||||
template<bool need_apply_mask, class Stage, class BlkCoord, class CountingTensor, class ProblemShape>
|
||||
template<bool need_apply_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
softmax_step(
|
||||
float& row_max, float& row_sum,
|
||||
Stage stage, bool final_call,
|
||||
BlkCoord const& blk_coord, CountingTensor const& cS,
|
||||
BlkCoord const& blk_coord, CoordTensor const& cS,
|
||||
Params const& params, ProblemShape const& problem_shape,
|
||||
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
|
||||
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
|
||||
@ -531,7 +531,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||
|
||||
// Each thread owns a single row
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
|
||||
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
|
||||
|
||||
@ -613,7 +613,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
|
||||
|
||||
const int kReleasePipeCount = 10; // must be multiple of 2
|
||||
|
||||
|
||||
order_s.wait();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
}
|
||||
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
|
||||
|
||||
|
||||
|
||||
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
|
||||
order_s.arrive();
|
||||
}
|
||||
@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
|
||||
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
|
||||
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
|
||||
|
||||
|
||||
row_sum = local_row_sum;
|
||||
|
||||
if (final_call) {
|
||||
@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
// good values would be either 32 or 64
|
||||
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
|
||||
|
||||
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
||||
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
||||
|
||||
typename CollectiveMmaPV::TiledMma mma;
|
||||
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||
Tensor tOsO = mma.get_slice(0).partition_C(sO);
|
||||
|
||||
|
||||
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
|
||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||
|
||||
|
||||
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
|
||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
|
||||
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
|
||||
@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
|
||||
|
||||
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
|
||||
|
||||
|
||||
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
|
||||
|
||||
|
||||
#ifndef ONLY_SOFTMAX
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(tTMrO); j += 2) {
|
||||
@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
// good values would be either 32 or 64
|
||||
const int kCorrectionTileSize = 16;
|
||||
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
||||
|
||||
typename CollectiveMmaPV::TiledMma mma;
|
||||
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||
|
||||
|
||||
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||
|
||||
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
|
||||
|
||||
|
||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
|
||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
|
||||
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
|
||||
|
||||
|
||||
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
|
||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
|
||||
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
|
||||
@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
float2 scale_f32x2 = make_float2(scale, scale);
|
||||
|
||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
||||
|
||||
|
||||
auto copy_in = [&](int i) {
|
||||
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
||||
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||
@ -942,16 +942,21 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
|
||||
template<
|
||||
class BlkCoord, class ProblemShape, class ParamsProblemShape,
|
||||
class TensorStorageEpi, class CollectiveEpilogue
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
correction(
|
||||
BlkCoord const& blk_coord,
|
||||
Params const& params, ProblemShape const& problem_shape,
|
||||
ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorageEpi& shared_storage_epi,
|
||||
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
|
||||
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
|
||||
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
|
||||
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
|
||||
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
|
||||
CollectiveEpilogue& epilogue) {
|
||||
|
||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
|
||||
|
||||
@ -961,7 +966,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
|
||||
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
|
||||
|
||||
|
||||
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
|
||||
|
||||
@ -1060,13 +1065,30 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
// F2FP
|
||||
// store to smem
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
|
||||
|
||||
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
|
||||
pipeline_o.consumer_release(pipeline_o_consumer_state);
|
||||
++pipeline_o_consumer_state;
|
||||
|
||||
|
||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||
++pipeline_epi_producer_state;
|
||||
|
||||
@ -1083,6 +1105,21 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
|
||||
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
|
||||
|
||||
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
|
||||
pipeline_o.consumer_release(pipeline_o_consumer_state);
|
||||
@ -1092,6 +1129,85 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
||||
++pipeline_epi_producer_state;
|
||||
}
|
||||
|
||||
|
||||
template<
|
||||
class BlkCoord, class ProblemShape, class ParamsProblemShape,
|
||||
class TensorStorageEpi, class CollectiveEpilogue
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
correction_empty(
|
||||
BlkCoord const& blk_coord,
|
||||
Params const& params, ProblemShape const& problem_shape,
|
||||
ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorageEpi& shared_storage_epi,
|
||||
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
|
||||
CollectiveEpilogue& epilogue) {
|
||||
|
||||
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
|
||||
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
|
||||
float lse = -INFINITY;
|
||||
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
|
||||
|
||||
#define DSHOW(x) print(#x ": "); print(x); print("\n")
|
||||
if (threadIdx.x % 128 == 0 && block0()) {
|
||||
DSHOW(sO);
|
||||
}
|
||||
#if 1
|
||||
|
||||
using ElementOut = typename CollectiveEpilogue::ElementOut;
|
||||
auto tiled_copy = make_cotiled_copy(
|
||||
Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
|
||||
make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
|
||||
sO.layout());
|
||||
|
||||
auto thr_copy = tiled_copy.get_slice(thread_idx);
|
||||
auto tOgO = thr_copy.partition_D(sO);
|
||||
auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
|
||||
clear(tOrO);
|
||||
|
||||
copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
|
||||
#endif
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||
++pipeline_epi_producer_state;
|
||||
|
||||
copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
|
||||
|
||||
if (epilogue.params.ptr_LSE != nullptr) {
|
||||
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
|
||||
|
||||
int row_offset = 0;
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
|
||||
}
|
||||
|
||||
if (row_idx < get<0>(problem_shape)) {
|
||||
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
pipeline_epi.producer_commit(pipeline_epi_producer_state);
|
||||
++pipeline_epi_producer_state;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
|
||||
@ -514,12 +514,12 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ...
|
||||
}
|
||||
|
||||
template<bool need_apply_mask, class Stage, class BlkCoord, class CountingTensor, class ProblemShape>
|
||||
template<bool need_apply_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
softmax_step(
|
||||
float& row_max, float& row_sum,
|
||||
Stage stage, bool final_call,
|
||||
BlkCoord const& blk_coord, CountingTensor const& cS,
|
||||
BlkCoord const& blk_coord, CoordTensor const& cS,
|
||||
Params const& params, ProblemShape const& problem_shape,
|
||||
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
|
||||
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
|
||||
@ -831,7 +831,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
// loop:
|
||||
// TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 128 / kCorrectionTileSize; i++) {
|
||||
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
|
||||
Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0;
|
||||
tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||
Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1;
|
||||
@ -917,7 +917,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
|
||||
|
||||
float2 scale_f32x2 = make_float2(scale, scale);
|
||||
|
||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<get<2>(TileShape{}) / kCorrectionTileSize>{}));
|
||||
|
||||
auto copy_in = [&](int i) {
|
||||
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
||||
|
||||
@ -170,8 +170,8 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
auto tSgQ = thr_mma_qk.partition_A(gQ);
|
||||
auto tScQ = thr_mma_qk.partition_A(cQ);
|
||||
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, _16>, Stride<Stride<_16, _32>, _1>>{};
|
||||
|
||||
auto tiled_copy_q = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},
|
||||
|
||||
@ -90,8 +90,8 @@ struct PersistentTileScheduler {
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_h;
|
||||
FastDivmod divmod_b;
|
||||
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
@ -146,7 +146,7 @@ struct PersistentTileScheduler {
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_h(block_decode, bidh, block_decode);
|
||||
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
|
||||
return make_coord(m_block, _0{}, make_coord(bidh, bidb));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@ -118,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
|
||||
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
|
||||
|
||||
|
||||
// compute S
|
||||
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
|
||||
SmemLayoutDQ{}(_, _, _0{})
|
||||
);
|
||||
|
||||
|
||||
return Params{
|
||||
args.problem_shape,
|
||||
args.mainloop,
|
||||
@ -452,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
|
||||
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
|
||||
|
||||
|
||||
auto tSTgK = cta_mma_kq.partition_A(gK);
|
||||
auto tSTgQ = cta_mma_kq.partition_B(gQ);
|
||||
auto tDPTgV = cta_mma_vdo.partition_A(gV);
|
||||
@ -477,7 +477,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
|
||||
|
||||
// set up lse and sum_odo
|
||||
|
||||
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
|
||||
|
||||
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
|
||||
@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}
|
||||
|
||||
// load Q
|
||||
if (cute::elect_one_sync()) {
|
||||
if (cute::elect_one_sync()) {
|
||||
cute::copy(
|
||||
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
|
||||
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||
@ -520,7 +520,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
&mLSE(gmem_idx, blk_coord_batch),
|
||||
gmem_idx < Q
|
||||
);
|
||||
|
||||
|
||||
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
++pipeline_load_compute_lse_producer_state;
|
||||
|
||||
@ -529,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
|
||||
|
||||
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
|
||||
|
||||
|
||||
// load V
|
||||
if (cute::elect_one_sync()) {
|
||||
cute::copy(
|
||||
@ -540,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}
|
||||
|
||||
// load dO
|
||||
if (cute::elect_one_sync()) {
|
||||
if (cute::elect_one_sync()) {
|
||||
cute::copy(
|
||||
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
|
||||
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||
@ -573,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
|
||||
|
||||
// load Q
|
||||
if (cute::elect_one_sync()) {
|
||||
if (cute::elect_one_sync()) {
|
||||
cute::copy(
|
||||
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
|
||||
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||
@ -584,7 +584,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
++pipeline_load_mma_q_producer_state;
|
||||
|
||||
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
|
||||
|
||||
|
||||
// load LSE
|
||||
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
|
||||
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
||||
@ -593,15 +593,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
&mLSE(gmem_idx, blk_coord_batch),
|
||||
gmem_idx < Q
|
||||
);
|
||||
|
||||
|
||||
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
++pipeline_load_compute_lse_producer_state;
|
||||
|
||||
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
|
||||
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
|
||||
|
||||
// load dO
|
||||
if (cute::elect_one_sync()) {
|
||||
// load dO
|
||||
if (cute::elect_one_sync()) {
|
||||
cute::copy(
|
||||
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
|
||||
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
|
||||
@ -612,7 +612,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
++pipeline_load_mma_do_producer_state;
|
||||
|
||||
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
|
||||
|
||||
|
||||
// load sum_OdO
|
||||
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
|
||||
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
|
||||
@ -621,7 +621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
&mSumOdO(gmem_idx, blk_coord_batch),
|
||||
gmem_idx < Q
|
||||
);
|
||||
|
||||
|
||||
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
++pipeline_load_compute_sum_odo_producer_state;
|
||||
|
||||
@ -639,23 +639,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
int iter_count,
|
||||
MainloopArguments const& mainloop_args,
|
||||
TensorStorage& shared_tensors,
|
||||
PipelineLoadMmaQ& pipeline_load_mma_q,
|
||||
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
|
||||
PipelineLoadMmaDO& pipeline_load_mma_do,
|
||||
PipelineLoadMmaQ& pipeline_load_mma_q,
|
||||
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
|
||||
PipelineLoadMmaDO& pipeline_load_mma_do,
|
||||
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
|
||||
PipelineMmaComputeS& pipeline_mma_compute_s,
|
||||
PipelineMmaComputeS& pipeline_mma_compute_s,
|
||||
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
|
||||
PipelineMmaComputeDP& pipeline_mma_compute_dp,
|
||||
PipelineMmaComputeDP& pipeline_mma_compute_dp,
|
||||
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
|
||||
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
|
||||
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
|
||||
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
|
||||
PipelineComputeMmaP& pipeline_compute_mma_p,
|
||||
PipelineComputeMmaP& pipeline_compute_mma_p,
|
||||
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
|
||||
PipelineComputeMmaDS& pipeline_compute_mma_ds,
|
||||
PipelineComputeMmaDS& pipeline_compute_mma_ds,
|
||||
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
|
||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
|
||||
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
|
||||
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
|
||||
@ -685,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
|
||||
tDVrP.data() = TmemAllocation::kP;
|
||||
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
|
||||
|
||||
|
||||
TiledMmaKQ tiled_mma_kq;
|
||||
TiledMmaVDO tiled_mma_vdo;
|
||||
TiledMmaDSK tiled_mma_dsk;
|
||||
@ -923,6 +923,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
TensorC const& coord,
|
||||
TensorShape const& tensor_shape) {
|
||||
|
||||
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
|
||||
|
||||
auto copy_op = make_cotiled_copy(
|
||||
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
|
||||
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
|
||||
@ -930,21 +932,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
);
|
||||
auto thr_copy = copy_op.get_slice(_0{});
|
||||
|
||||
auto tCg = thr_copy.partition_D(gmem);
|
||||
auto tCr = thr_copy.partition_S(quantize(regs));
|
||||
auto tCc = thr_copy.partition_D(coord);
|
||||
Tensor tCg = thr_copy.partition_D(gmem);
|
||||
Tensor tCr = thr_copy.partition_S(quantize(regs));
|
||||
Tensor tPc = thr_copy.partition_D(preds);
|
||||
|
||||
constexpr int R = decltype(tCr.layout())::rank;
|
||||
auto tCg_v = group_modes<1, R>(tCg);
|
||||
auto tCr_v = group_modes<1, R>(tCr);
|
||||
auto tCc_v = group_modes<1, R>(tCc);
|
||||
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
|
||||
|
||||
for (int i = 0; i < size(tCp_v); ++i) {
|
||||
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
|
||||
}
|
||||
|
||||
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
|
||||
copy_if(copy_op, tPc, tCr, tCg);
|
||||
}
|
||||
|
||||
|
||||
@ -1073,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
|
||||
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
|
||||
|
||||
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
|
||||
// in tmem, S & P overlap
|
||||
@ -1114,7 +1106,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
|
||||
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
|
||||
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
|
||||
|
||||
|
||||
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
|
||||
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
|
||||
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
|
||||
@ -1152,20 +1144,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
fn(cute::false_type{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
|
||||
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {
|
||||
|
||||
// compute P = softmax(S, LSE)
|
||||
cute::copy(tiled_t2r, tTR_tST, tTR_rST);
|
||||
|
||||
|
||||
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
|
||||
Mask{}.apply_mask(tTR_rST, [&](int i) {
|
||||
auto c_transpose = tTR_cST(i);
|
||||
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
|
||||
}, problem_shape);
|
||||
}
|
||||
|
||||
|
||||
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
|
||||
float2 softmax_scale_log2_e;
|
||||
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
|
||||
@ -1184,16 +1176,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
tTR_rST(i) = ::exp2f(out.x);
|
||||
tTR_rST(i+1) = ::exp2f(out.y);
|
||||
}
|
||||
|
||||
|
||||
auto tRT_rST = quantize(tTR_rST);
|
||||
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
|
||||
|
||||
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
cutlass::arch::NamedBarrier(
|
||||
kNumComputeWarps * NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::TransformBarrier
|
||||
).arrive_and_wait();
|
||||
|
||||
|
||||
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
|
||||
});
|
||||
|
||||
@ -1293,9 +1285,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
|
||||
PipelineReduceTmaStore& pipeline_reduce_tma_store,
|
||||
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
|
||||
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
|
||||
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
|
||||
@ -1307,7 +1299,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
tDQtDQ.data() = TmemAllocation::kDQ;
|
||||
|
||||
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
|
||||
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
|
||||
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
|
||||
(_, _, _, _0{}, blk_coord_batch);
|
||||
|
||||
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
|
||||
@ -1376,7 +1368,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
iter_index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
@ -1561,7 +1553,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
|
||||
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
|
||||
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
|
||||
|
||||
|
||||
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
|
||||
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
|
||||
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
|
||||
@ -1587,7 +1579,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
if (role == WarpRole::Load) {
|
||||
warpgroup_reg_set<RegisterAllocation::kLoad>();
|
||||
|
||||
|
||||
load(
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
@ -1596,7 +1588,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
params.mainloop,
|
||||
params.mainloop_params,
|
||||
shared_storage.tensors,
|
||||
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
|
||||
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
|
||||
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
|
||||
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
|
||||
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
|
||||
@ -1608,7 +1600,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
|
||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||
__syncwarp();
|
||||
|
||||
|
||||
mma(
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
@ -1616,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
iter_count,
|
||||
params.mainloop,
|
||||
shared_storage.tensors,
|
||||
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
|
||||
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
|
||||
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
|
||||
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
|
||||
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
|
||||
@ -1629,7 +1621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}
|
||||
else if (role == WarpRole::Compute) {
|
||||
warpgroup_reg_set<RegisterAllocation::kCompute>();
|
||||
|
||||
|
||||
compute(
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
@ -1660,7 +1652,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}
|
||||
else if (role == WarpRole::Reduce) {
|
||||
warpgroup_reg_set<RegisterAllocation::kReduce>();
|
||||
|
||||
|
||||
reduce(
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
@ -1677,9 +1669,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
|
||||
}
|
||||
else {
|
||||
warpgroup_reg_set<RegisterAllocation::kEmpty>();
|
||||
|
||||
|
||||
/* no-op */
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
|
||||
|
||||
CollectiveMainloop mainloop;
|
||||
CollectiveEpilogue epilogue;
|
||||
CollectiveEpilogue epilogue{params.epilogue};
|
||||
|
||||
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
|
||||
warpgroup_reg_set<NumRegsSoftmax>();
|
||||
@ -372,6 +372,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_softmax_0 = role == WarpRole::Softmax0;
|
||||
|
||||
mainloop.softmax(
|
||||
@ -400,17 +404,30 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
mainloop.correction_empty(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state,
|
||||
epilogue
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.correction(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
||||
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state,
|
||||
epilogue
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
if constexpr (NumWarpsEpilogue == 0) {
|
||||
@ -438,6 +455,9 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.mma(
|
||||
blk_coord,
|
||||
@ -450,7 +470,6 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
pipeline_mma_corr, pipeline_mma_corr_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Load) {
|
||||
@ -467,6 +486,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (get<1>(logical_problem_shape) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.load(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.mainloop, params.problem_shape,
|
||||
|
||||
@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
|
||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
@ -784,7 +784,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
auto pages_per_tile = Pow2{TileShapeS{} / page_size};
|
||||
int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp;
|
||||
|
||||
#if 1
|
||||
for (; k_tile_count > 0; ++k_index, --k_tile_count) {
|
||||
pipeline_page_table.producer_acquire(pipeline_pt_producer_state);
|
||||
|
||||
@ -805,7 +804,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
++pipeline_pt_producer_state;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -1639,7 +1637,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]);
|
||||
}
|
||||
|
||||
#ifndef B2B
|
||||
// find correction factor
|
||||
ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast<ElementAcc>(M_LOG2E);
|
||||
correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new));
|
||||
@ -1651,7 +1648,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (int i = 0; i < size(tTR_rAcc); i++) {
|
||||
tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2);
|
||||
}
|
||||
#endif
|
||||
|
||||
// quantize
|
||||
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
|
||||
@ -1705,7 +1701,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
uint32_t tmem_o) {
|
||||
|
||||
// for b2b gemm, do nothing
|
||||
#ifndef B2B
|
||||
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
|
||||
auto store_op = TMEM::tmem_load_to_store(load_op);
|
||||
|
||||
@ -1748,7 +1743,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
// store o
|
||||
copy(tiled_r2t, tTR_rAcc, tTR_tAcc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -1806,8 +1800,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
copy(tTR_rO_src, tR2G_rO_dst);
|
||||
|
||||
#ifndef B2B
|
||||
|
||||
// compute LSE
|
||||
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
|
||||
|
||||
@ -1819,7 +1811,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
{
|
||||
gLSE(threadIdx.x) = lse;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else {
|
||||
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
|
||||
@ -1848,7 +1839,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
copy(tTR_rO_src, tR2G_rO_dst);
|
||||
|
||||
#ifndef B2B
|
||||
|
||||
if (epilogue_args.ptr_lse != nullptr) {
|
||||
// compute LSE
|
||||
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
|
||||
@ -1863,7 +1854,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
gLSE(threadIdx.x) = lse;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -1980,9 +1970,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
|
||||
pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);
|
||||
|
||||
#ifdef B2B
|
||||
row_sum = 1;
|
||||
#else
|
||||
if constexpr (kWarpsInN > 1) {
|
||||
// reduce row_sum if needed (for 2x2 dp)
|
||||
shared_tensors.smem_exchange[threadIdx.x] = row_sum;
|
||||
@ -1991,7 +1978,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
int peer_index = (threadIdx.x + 64) % 128;
|
||||
row_sum += shared_tensors.smem_exchange[peer_index];
|
||||
}
|
||||
#endif
|
||||
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();
|
||||
|
||||
|
||||
@ -80,6 +80,17 @@ void __global__ fmha_reference_kernel(
|
||||
if constexpr (rank<1>(decltype(coord){}) == 2) {
|
||||
offset_K = get<1,1>(coord);
|
||||
}
|
||||
|
||||
if (get<1>(problem_shape) == 0) {
|
||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
||||
mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
|
||||
mLSE(idx_Q + offset_Q, idx_L) = -INFINITY;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
@ -127,7 +138,7 @@ void __global__ fmha_reference_kernel(
|
||||
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
|
||||
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
|
||||
}
|
||||
|
||||
|
||||
@ -111,11 +111,9 @@ void __global__ fmha_mla_reference_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifndef B2B
|
||||
for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) {
|
||||
mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS));
|
||||
}
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@ -125,9 +123,6 @@ void __global__ fmha_mla_reference_kernel(
|
||||
}
|
||||
|
||||
ElementAcc o_scale = 1.0f / sum;
|
||||
#ifdef B2B
|
||||
o_scale = 1.0;
|
||||
#endif
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) {
|
||||
ElementAcc acc = 0;
|
||||
|
||||
@ -75,6 +75,8 @@ struct DeviceAllocation {
|
||||
|
||||
size_t size() const { return size_; }
|
||||
|
||||
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
|
||||
|
||||
void copy_from_host(const T* ptr, size_t sz) {
|
||||
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
|
||||
assert(ret == cudaSuccess);
|
||||
@ -99,8 +101,12 @@ __global__ void reference_abs_diff_kernel(
|
||||
__shared__ double block_sum_diff;
|
||||
|
||||
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
if (data[i] == data_ref[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
double diff = fabs(data[i] - data_ref[i]);
|
||||
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
thread_max_diff = fmax(diff, thread_max_diff);
|
||||
thread_sum_diff += diff;
|
||||
}
|
||||
@ -192,8 +198,11 @@ __global__ void reference_rel_diff_kernel(
|
||||
__shared__ double block_sum_diff;
|
||||
|
||||
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
if (data[i] == data_ref[i]) {
|
||||
continue;
|
||||
}
|
||||
double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]);
|
||||
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
thread_max_diff = fmax(diff, thread_max_diff);
|
||||
thread_sum_diff += diff;
|
||||
}
|
||||
|
||||
@ -280,7 +280,7 @@ auto make_iterator(T* ptr) {
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
@ -861,7 +861,7 @@ int run(Options &options, bool host_problem_shapes_available = true)
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
std::cout << " TFLOPS : " << result.gflops / 1000.0 << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@ -132,8 +132,8 @@ using namespace cute;
|
||||
using TP = _8;
|
||||
static constexpr int TP_ = TP{};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
|
||||
// Distributed GEMM tiling/sharding schedule
|
||||
// Choices:
|
||||
@ -254,7 +254,8 @@ HostTensorB tensor_B_arr[TP_];
|
||||
HostTensorD tensor_C_arr[TP_];
|
||||
HostTensorD tensor_D_arr[TP_];
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
@ -346,8 +347,8 @@ struct Result {
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -805,17 +806,16 @@ int run(Options &options) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
|
||||
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
|
||||
// and must have compute capability at least 90.
|
||||
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
|
||||
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
|
||||
// CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run Blackwell kernels.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
@ -861,8 +861,12 @@ int main(int argc, char const **args) {
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
|
||||
run(options);
|
||||
#else
|
||||
std::cerr
|
||||
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.8 or later." << std::endl;
|
||||
return 0;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -14,8 +14,8 @@ cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
|
||||
### Minimum software
|
||||
|
||||
Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
|
||||
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
|
||||
CUDA graph APIs.
|
||||
This example specifically requires CUDA Toolkit 12.8 or newer, since that is the first version
|
||||
supporting the Blackwell architecture.
|
||||
|
||||
### Hardware / driver settings
|
||||
|
||||
|
||||
1192
examples/88_hopper_fmha/88_hopper_fmha.cu
Normal file
1192
examples/88_hopper_fmha/88_hopper_fmha.cu
Normal file
File diff suppressed because it is too large
Load Diff
50
examples/88_hopper_fmha/CMakeLists.txt
Normal file
50
examples/88_hopper_fmha/CMakeLists.txt
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
88_hopper_fmha
|
||||
88_hopper_fmha.cu
|
||||
)
|
||||
|
||||
if(NOT WIN32 AND NOT CUTLASS_CLANG_HOST_COMPILE)
|
||||
|
||||
set_property(
|
||||
SOURCE 88_hopper_fmha.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math"
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
88_hopper_fmha_fp8
|
||||
88_hopper_fmha.cu
|
||||
)
|
||||
|
||||
target_compile_definitions(
|
||||
88_hopper_fmha_fp8
|
||||
PRIVATE FP8)
|
||||
|
||||
endif()
|
||||
77
examples/88_hopper_fmha/README.md
Normal file
77
examples/88_hopper_fmha/README.md
Normal file
@ -0,0 +1,77 @@
|
||||
# CUTLASS Hopper FMHA Example
|
||||
|
||||
This sample showcases how to implement fused multi-head attention (FMHA) using
|
||||
CUTLASS for the NVIDIA Hopper architecture. At its heart, the forward pass of
|
||||
FMHA is a GEMM-online softmax-GEMM fusion, whereas the backward pass is a slightly
|
||||
more complex structure (basically, a GEMM-softmax-2xGEMM-2xGEMM fusion).
|
||||
For more information please refer to the [Flash Attention 3 paper](https://arxiv.org/abs/2407.08608).
|
||||
|
||||
The forward pass kernel supports head dims 32, 64, 128, and 256 for fp16 and bf16 input data types,
|
||||
and head dims 128, and 256 for fp8.
|
||||
All kernels use the Tensor Memory Accelerator for loads.
|
||||
Kernels with head dims 128 and 256 have warp-specialized cooperative schedules.
|
||||
|
||||
Backward pass kernels (fp16 only) support head dims 32, 64, and 128, and all support
|
||||
warp-specialized cooperative schedules.
|
||||
|
||||
## Customization
|
||||
|
||||
### Mask Fusion
|
||||
|
||||
Similar to the [Blackwell FMHA example](../77_blackwell_fmha/README.md), attention masks such as
|
||||
causal masking can be fused into the kernel. To modify the code for such fusions,
|
||||
`collective/fmha_fusion.hpp` provides the easiest customization point.
|
||||
The `before_softmax` function is called with the accumulator of the first GEMM and the logical
|
||||
positions of those elements. It is well-suited for applying masks or activations.
|
||||
|
||||
### MHA Variants
|
||||
|
||||
Using CuTe, it is easy to represent the various attention variants.
|
||||
Where regular multi-head attention's layout for the head dimension is (numHeads:headStride),
|
||||
for single-head attention it is simply (1:0) everywhere,
|
||||
for GQA it is normal in Q and (numHeads/numGroups,numGroups:headStride,0) in KV,
|
||||
and for MQA it is normal for Q and (numHeads:0) in KV.
|
||||
As such, beyond general stride handling, no additional work is needed to support these,
|
||||
and the example will just demonstrate regular multi-head attention.
|
||||
|
||||
### FP8
|
||||
|
||||
The warp-specialized forward kernel supports FP8 computation with both FP32 and FP16
|
||||
accumulation for the Q*K product. They can be enabled in the runner by defining FP8.
|
||||
|
||||
## Performance
|
||||
Forward pass kernels can generally come close to that of FA3, but backward pass
|
||||
kernels are more limited in performance and are not expected to reach the same level of performance
|
||||
as FA3.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
@ -0,0 +1,863 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "../collective/fmha_common.hpp"
|
||||
#include "../collective/fmha_collective_load.hpp"
|
||||
#include "../collective/fmha_collective_softmax.hpp"
|
||||
#include "../kernel/fmha_options.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template<
|
||||
typename Element_,
|
||||
typename ElementAccumulator_,
|
||||
typename TileShape_, // BlockQO, BlockKV, BlockHead
|
||||
class Fusion,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaBwdMainloopTmaWarpSpecialized {
|
||||
|
||||
using Element = Element_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using TileShape = TileShape_;
|
||||
|
||||
static constexpr bool kIsPersistent = false;
|
||||
|
||||
static const int NumLoadWarpGroups = 1;
|
||||
static constexpr int NumMmaWarpGroups = 2;
|
||||
static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups;
|
||||
static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */;
|
||||
|
||||
static const int kOuterLoads = 2;
|
||||
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
|
||||
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
static_assert(StagesQ::value >= 2);
|
||||
static_assert(Stages::value >= 2 * NumMmaWarpGroups);
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int Alignment = 16 / sizeof(Element);
|
||||
|
||||
using TileShapeNM = Shape< // (N,M,D)
|
||||
decltype(tuple_element_t<1, TileShape>{} / Int<NumMmaWarpGroups>{}),
|
||||
tuple_element_t<0, TileShape>,
|
||||
tuple_element_t<2, TileShape>>;
|
||||
|
||||
using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M)
|
||||
|
||||
using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N)
|
||||
|
||||
using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||
ElementAccumulator,
|
||||
TileShapeNM, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
|
||||
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
|
||||
ElementAccumulator,
|
||||
TileShapeND, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
|
||||
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
|
||||
ElementAccumulator,
|
||||
TileShapeND, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
|
||||
using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment, // from smem, might matter (?)
|
||||
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
|
||||
ElementAccumulator,
|
||||
TileShapeMD, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
using TiledMmaNM = typename CollectiveMmaNM::TiledMma;
|
||||
using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma;
|
||||
using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{}));
|
||||
using TiledMmaND = TiledMmaND_RS;
|
||||
using TiledMmaMD = typename CollectiveMmaMD::TiledMma;
|
||||
|
||||
using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB;
|
||||
using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA;
|
||||
using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA;
|
||||
using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB;
|
||||
|
||||
//using SmemLayoutDQ = Layout<
|
||||
// Shape<
|
||||
// tuple_element_t<0, TileShapeMD>,
|
||||
// Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>,
|
||||
// _2
|
||||
// >,
|
||||
// Stride<
|
||||
// _4,
|
||||
// Stride<decltype(tuple_element_t<0, TileShapeMD>{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>,
|
||||
// decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
|
||||
// >>;
|
||||
|
||||
using SmemLayoutDQ_0 = Layout<
|
||||
Shape<
|
||||
tuple_element_t<0, TileShapeMD>,
|
||||
tuple_element_t<1, TileShapeMD>,
|
||||
_2
|
||||
>,
|
||||
Stride<
|
||||
tuple_element_t<1, TileShapeMD>,
|
||||
_1,
|
||||
decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
|
||||
>>;
|
||||
|
||||
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>());
|
||||
using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{}));
|
||||
using SmemLayoutDQ = SmemLayoutDQ_1;
|
||||
|
||||
|
||||
using PipelineDQ = cutlass::PipelineAsync<2>;
|
||||
|
||||
|
||||
using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int<NumMmaWarpGroups>{}));
|
||||
|
||||
using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom<Element>{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{}));
|
||||
using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB;
|
||||
|
||||
using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB;
|
||||
using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB;
|
||||
|
||||
using SmemLayoutLSE = Layout<Shape<tuple_element_t<1, TileShapeNM>, Int<StageCount>>>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
|
||||
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
|
||||
|
||||
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||
|
||||
using TileShapePV = TileShapeND; // To work with the kernel level
|
||||
using TiledMmaPV = TiledMmaND;
|
||||
|
||||
static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator);
|
||||
static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
|
||||
|
||||
struct SharedStorage {
|
||||
// One for each consumer WG
|
||||
union {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutKp>> smem_kp;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
|
||||
};
|
||||
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;
|
||||
|
||||
// Loaded by producer, consumed by both WGs
|
||||
union {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQp>> smem_qp;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDOp>> smem_dop;
|
||||
};
|
||||
|
||||
// Accumulated into by both consumers, potentially loaded, potentially written
|
||||
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutDQ>> smem_dq;
|
||||
|
||||
union {
|
||||
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_lse;
|
||||
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_sumOdO;
|
||||
};
|
||||
};
|
||||
|
||||
struct Arguments {
|
||||
const Element* ptr_Q;
|
||||
cute::tuple<int, int, int, _1> dQ;
|
||||
const Element* ptr_K;
|
||||
cute::tuple<int, int, int, _1> dK;
|
||||
const Element* ptr_V;
|
||||
cute::tuple<int, int, int, _1> dV;
|
||||
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, int, int, _1> dDO;
|
||||
|
||||
const ElementAccumulator* ptr_LSE;
|
||||
cute::tuple<int, int, _1> dLSE;
|
||||
const ElementAccumulator* ptr_sum_OdO;
|
||||
cute::tuple<int, int, _1> dSumOdO;
|
||||
|
||||
ElementAccumulator* ptr_dQ;
|
||||
cute::tuple<int, int, int, _1> dDQ;
|
||||
};
|
||||
|
||||
using TMA_Q = typename CollectiveMmaNM::Params::TMA_B;
|
||||
using TMA_K = typename CollectiveMmaNM::Params::TMA_A;
|
||||
using TMA_V = typename CollectiveMmaNM::Params::TMA_A;
|
||||
using TMA_DO = typename CollectiveMmaNM::Params::TMA_B;
|
||||
|
||||
using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{})));
|
||||
using TMA_ODO = TMA_LSE;
|
||||
|
||||
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{})));
|
||||
|
||||
using LoadQ = CollectiveLoadTma<
|
||||
LoadKind::kBwdM,
|
||||
MainloopPipeline,
|
||||
Element,
|
||||
SmemLayoutQ,
|
||||
TMA_Q
|
||||
>;
|
||||
|
||||
using LoadK = CollectiveLoadTma<
|
||||
LoadKind::kBwdN,
|
||||
MainloopPipelineQ,
|
||||
Element,
|
||||
SmemLayoutK,
|
||||
TMA_K
|
||||
>;
|
||||
|
||||
using LoadV = CollectiveLoadTma<
|
||||
LoadKind::kBwdN,
|
||||
MainloopPipelineQ,
|
||||
Element,
|
||||
SmemLayoutV,
|
||||
TMA_V
|
||||
>;
|
||||
|
||||
using LoadDO = CollectiveLoadTma<
|
||||
LoadKind::kBwdM,
|
||||
MainloopPipeline,
|
||||
Element,
|
||||
SmemLayoutDO,
|
||||
TMA_DO
|
||||
>;
|
||||
|
||||
using LoadLSE = CollectiveLoadTma<
|
||||
LoadKind::kBwdScalar,
|
||||
MainloopPipeline,
|
||||
ElementAccumulator,
|
||||
SmemLayoutLSE,
|
||||
TMA_LSE
|
||||
>;
|
||||
|
||||
using LoadODO = CollectiveLoadTma<
|
||||
LoadKind::kBwdScalar,
|
||||
MainloopPipeline,
|
||||
ElementAccumulator,
|
||||
SmemLayoutLSE,
|
||||
TMA_ODO
|
||||
>;
|
||||
|
||||
struct Params {
|
||||
TMA_Q tma_load_q;
|
||||
TMA_K tma_load_k;
|
||||
TMA_V tma_load_v;
|
||||
TMA_DO tma_load_do;
|
||||
|
||||
TMA_LSE tma_load_lse;
|
||||
TMA_ODO tma_load_odo;
|
||||
|
||||
TMA_DQ tma_red_dq;
|
||||
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
};
|
||||
|
||||
static_assert(size(TiledMmaNM{}) == size(TiledMmaND{}));
|
||||
static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{}));
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
|
||||
return true
|
||||
&& (get<4>(problem_size) <= get<2>(TileShape{}))
|
||||
&& ((get<4>(problem_size) % Alignment) == 0)
|
||||
&& ((get<2>(problem_size) % Alignment) == 0)
|
||||
;
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
|
||||
auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||
|
||||
auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK)));
|
||||
auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ)));
|
||||
auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
|
||||
typename CollectiveMmaNM::Arguments {
|
||||
args.ptr_K, dK,
|
||||
args.ptr_Q, dQ,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV)));
|
||||
auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO)));
|
||||
auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
|
||||
typename CollectiveMmaNM::Arguments {
|
||||
args.ptr_V, dV,
|
||||
args.ptr_dO, dDO,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
|
||||
TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{}));
|
||||
TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{}));
|
||||
|
||||
TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{}));
|
||||
|
||||
return Params{
|
||||
params_nm_kq.tma_load_b,
|
||||
params_nm_kq.tma_load_a,
|
||||
params_nm_vdo.tma_load_a,
|
||||
params_nm_vdo.tma_load_b,
|
||||
tma_load_lse, tma_load_odo,
|
||||
tma_red_dq,
|
||||
1.0f / (float) std::sqrt(get<4>(problem_size)),
|
||||
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size)))
|
||||
};
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
auto
|
||||
get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) {
|
||||
return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<bool kLoadOuter, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||
CUTLASS_DEVICE void
|
||||
load_kv_maybe_q(
|
||||
int block_rank_in_cluster,
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner,
|
||||
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
|
||||
SharedStorage& storage,
|
||||
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||
{
|
||||
// Load pattern:
|
||||
// K0 V0 K1 V1
|
||||
// Q0 DO0 Q1 DO1 Q2 DO2 ...
|
||||
// K0 Q0 V0 K1 DO0 V1 ...
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
int outer_tile_count = NumMmaWarpGroups;
|
||||
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
|
||||
|
||||
auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
|
||||
auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count);
|
||||
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q};
|
||||
auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
|
||||
|
||||
LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do};
|
||||
auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
|
||||
|
||||
LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
|
||||
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||
|
||||
LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
|
||||
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||
|
||||
LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse};
|
||||
auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||
|
||||
LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO};
|
||||
auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||
|
||||
outer_tile_count *= 2; // K & V
|
||||
inner_tile_count *= 4; // Q & dO & LSE & sumOdO
|
||||
|
||||
while (inner_tile_count > 0) {
|
||||
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
|
||||
break;
|
||||
}
|
||||
inner_tile_count -= 4;
|
||||
++inner_tile_iter;
|
||||
}
|
||||
|
||||
if constexpr (kLoadOuter) {
|
||||
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
}
|
||||
|
||||
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
|
||||
if constexpr (! kLoadOuter) {
|
||||
if (do_barrier) {
|
||||
load_warp_barrier.arrive();
|
||||
load_warp_barrier.wait(/*phase=*/ 0);
|
||||
do_barrier = false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (kLoadOuter) {
|
||||
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
}
|
||||
|
||||
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
|
||||
if constexpr (kLoadOuter) {
|
||||
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
}
|
||||
|
||||
if constexpr (kLoadOuter) {
|
||||
while (outer_tile_count > 0) {
|
||||
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (inner_tile_count > 0) {
|
||||
while (inner_tile_count > 0) {
|
||||
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
|
||||
break;
|
||||
}
|
||||
inner_tile_count -= 4;
|
||||
++inner_tile_iter;
|
||||
}
|
||||
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
|
||||
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||
CUTLASS_DEVICE void
|
||||
load_maybe_q(
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
|
||||
SharedStorage& storage,
|
||||
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||
{
|
||||
// Load pattern:
|
||||
// K0 V0 K1 V1
|
||||
// Q0 DO0 Q1 DO1 Q2 DO2 ...
|
||||
// K0 Q0 V0 K1 DO0 V1 ...
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
int outer_tile_count = NumMmaWarpGroups;
|
||||
|
||||
auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
|
||||
|
||||
LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
|
||||
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||
|
||||
LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
|
||||
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||
|
||||
outer_tile_count *= 2; // K & V
|
||||
|
||||
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
|
||||
if (do_barrier) {
|
||||
load_warp_barrier.arrive();
|
||||
load_warp_barrier.wait(/*phase=*/ 0);
|
||||
do_barrier = false;
|
||||
}
|
||||
|
||||
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
|
||||
while (outer_tile_count > 0) {
|
||||
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
|
||||
CUTLASS_DEVICE void
|
||||
reduce(
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer,
|
||||
SharedStorage& storage)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size));
|
||||
Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{});
|
||||
Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord));
|
||||
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
|
||||
|
||||
auto block_tma = params.tma_red_dq.get_slice(_0{});
|
||||
|
||||
Tensor tDQsDQ = block_tma.partition_S(sDQ);
|
||||
Tensor tDQgDQ = block_tma.partition_D(gDQ);
|
||||
|
||||
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
|
||||
int g_index = 0;
|
||||
|
||||
auto smem_pipe_release_reducer = smem_pipe_read_reducer;
|
||||
bool first = true;
|
||||
while (inner_tile_count > 0) {
|
||||
while (inner_tile_count > 0) {
|
||||
if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
|
||||
break;
|
||||
}
|
||||
inner_tile_count -= 1;
|
||||
++g_index;
|
||||
}
|
||||
if (inner_tile_count == 0) break;
|
||||
|
||||
pipeline_reducer.consumer_wait(smem_pipe_read_reducer);
|
||||
if (lane_predicate == 1) {
|
||||
tma_store_wait<1>();
|
||||
}
|
||||
if (! first) {
|
||||
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
|
||||
++smem_pipe_release_reducer;
|
||||
} else {
|
||||
first = false;
|
||||
}
|
||||
if (lane_predicate == 1) {
|
||||
copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index));
|
||||
tma_store_arrive();
|
||||
}
|
||||
++smem_pipe_read_reducer;
|
||||
--inner_tile_count;
|
||||
++g_index;
|
||||
}
|
||||
if (lane_predicate) {
|
||||
tma_store_wait<0>();
|
||||
}
|
||||
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
|
||||
++smem_pipe_release_reducer;
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
|
||||
CUTLASS_DEVICE auto
|
||||
compute(
|
||||
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
|
||||
Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner,
|
||||
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer,
|
||||
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
|
||||
SharedStorage& storage,
|
||||
MathWgOrderBarrier& math_wg_order_barrier)
|
||||
{
|
||||
TiledMmaND tiled_mma_nd;
|
||||
|
||||
Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
|
||||
clear(acc_DV);
|
||||
|
||||
Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
|
||||
clear(acc_DK);
|
||||
|
||||
int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineState smem_pipe_release_inner = smem_pipe_read_inner;
|
||||
|
||||
pipeline_outer.consumer_wait(smem_pipe_read_outer);
|
||||
PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer;
|
||||
++smem_pipe_read_outer;
|
||||
pipeline_outer.consumer_wait(smem_pipe_read_outer);
|
||||
PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer;
|
||||
|
||||
int inner_tile_count = get_inner_tile_count(wg_coord, problem_size);
|
||||
|
||||
TiledMmaNM tiled_mma_nm;
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx);
|
||||
Tensor tSsK = thr_mma_nm.partition_A(sK);
|
||||
Tensor tSsQ = thr_mma_nm.partition_B(sQ);
|
||||
Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK);
|
||||
Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ);
|
||||
|
||||
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{});
|
||||
|
||||
Tensor tDPsV = thr_mma_nm.partition_A(sV);
|
||||
Tensor tDPsDO = thr_mma_nm.partition_B(sDO);
|
||||
Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV);
|
||||
Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO);
|
||||
|
||||
auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{});
|
||||
Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp);
|
||||
Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO);
|
||||
|
||||
|
||||
Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{});
|
||||
Tensor tDK_sQ = thr_mma_nd.partition_B(sQp);
|
||||
Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ);
|
||||
|
||||
|
||||
int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0);
|
||||
|
||||
TiledMmaMD tiled_mma_md;
|
||||
auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx);
|
||||
Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{});
|
||||
Tensor tDQsDS = thr_mma_md.partition_A(sDS);
|
||||
Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS);
|
||||
Tensor tDQrDS = tDQrDS_full(_,_,_,_);
|
||||
Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{});
|
||||
Tensor tDQsK = thr_mma_md.partition_B(sKp);
|
||||
Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK);
|
||||
|
||||
Tensor sLSE = make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
|
||||
Tensor tSsLSE = thr_mma_nm.partition_C(sLSE);
|
||||
|
||||
Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
|
||||
Tensor tDPsODO = thr_mma_nm.partition_C(sODO);
|
||||
|
||||
Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{}));
|
||||
Tensor tScS = thr_mma_nm.partition_C(cS);
|
||||
int n_block = get<1>(wg_coord);
|
||||
tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{});
|
||||
|
||||
|
||||
// Transpose
|
||||
Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS))));
|
||||
Tensor sDSp = sDSp_full(_,_,_);
|
||||
Tensor tDPsDS = thr_mma_nm.partition_C(sDSp);
|
||||
|
||||
auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx);
|
||||
Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp);
|
||||
|
||||
Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp);
|
||||
|
||||
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
|
||||
auto tDQsDQ_full = thr_mma_md.partition_C(sDQ);
|
||||
|
||||
|
||||
auto smem_pipe_read_k_other = smem_pipe_read_k;
|
||||
smem_pipe_read_k_other.advance(2);
|
||||
|
||||
int k_index = 0;
|
||||
|
||||
while (inner_tile_count > 0) {
|
||||
while (inner_tile_count > 0) {
|
||||
if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
|
||||
break;
|
||||
}
|
||||
inner_tile_count -= 1;
|
||||
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
|
||||
k_index += 1;
|
||||
}
|
||||
if (inner_tile_count == 0) break;
|
||||
|
||||
pipeline_inner.consumer_wait(smem_pipe_read_inner);
|
||||
PipelineState smem_pipe_read_q = smem_pipe_read_inner;
|
||||
++smem_pipe_read_inner;
|
||||
PipelineState smem_pipe_read_do = smem_pipe_read_inner;
|
||||
++smem_pipe_read_inner;
|
||||
|
||||
// GEMM KQ -> S
|
||||
Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
|
||||
|
||||
warpgroup_fence_operand(acc_S);
|
||||
warpgroup_arrive();
|
||||
gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
pipeline_inner.consumer_wait(smem_pipe_read_do);
|
||||
|
||||
// GEMM VdO -> dP
|
||||
Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
|
||||
|
||||
warpgroup_fence_operand(acc_DP);
|
||||
warpgroup_arrive();
|
||||
gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
Tensor reg_LSE = make_fragment_like<ElementAccumulator>(acc_S);
|
||||
for (int i = 0; i < size(reg_LSE); i++) {
|
||||
reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i);
|
||||
}
|
||||
|
||||
Tensor reg_ODO = make_fragment_like<ElementAccumulator>(acc_S);
|
||||
if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) {
|
||||
for (int i = 0; i < size(reg_ODO); i++) {
|
||||
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_wait<1>();
|
||||
warpgroup_fence_operand(acc_S);
|
||||
|
||||
math_wg_order_barrier.wait();
|
||||
// Compute S -> P
|
||||
Fusion{}.before_softmax(acc_S, tScS, problem_size);
|
||||
auto acc_P = make_fragment_like<ElementAccumulator>(acc_S);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_P); i++) {
|
||||
acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i));
|
||||
}
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) {
|
||||
for (int i = 0; i < size(reg_ODO); i++) {
|
||||
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_DP);
|
||||
|
||||
// Compute dP P -> dS
|
||||
auto acc_DS = make_fragment_like<Element>(acc_DP);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_DS); i++) {
|
||||
// We could move the scale out and into the respective epilogues (or a final scaling step)
|
||||
acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i));
|
||||
}
|
||||
|
||||
// GEMM PdO -> dV
|
||||
auto op_P = make_acc_into_op<Element>(acc_P, typename TiledMmaND::LayoutA_TV{});
|
||||
warpgroup_fence_operand(acc_DV);
|
||||
warpgroup_fence_operand(op_P);
|
||||
warpgroup_arrive();
|
||||
cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Store dS to smem dS'
|
||||
if (wg_idx == 0) math_wg_order_barrier.wait();
|
||||
|
||||
auto recast_bits = [](auto sz, auto t) {
|
||||
return recast<uint_bit_t<decltype(sz)::value>>(t);
|
||||
};
|
||||
auto tDPsDS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, tDPsDS);
|
||||
auto acc_DS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, acc_DS);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_DS_v); i++) {
|
||||
tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i);
|
||||
}
|
||||
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
if (wg_idx == 0) math_wg_order_barrier.arrive();
|
||||
|
||||
// GEMM dS Q -> dK
|
||||
if (wg_idx == 1) {
|
||||
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
// GEMM dS' K -> dQ
|
||||
Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{}));
|
||||
|
||||
warpgroup_fence_operand(acc_DQ);
|
||||
warpgroup_arrive();
|
||||
gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ);
|
||||
cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
warpgroup_fence_operand(acc_DK);
|
||||
warpgroup_arrive();
|
||||
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
warpgroup_wait<1>();
|
||||
warpgroup_fence_operand(acc_DK);
|
||||
|
||||
warpgroup_wait<1>();
|
||||
warpgroup_fence_operand(acc_DQ);
|
||||
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
|
||||
auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index());
|
||||
|
||||
// Store dQ to smem dQ'
|
||||
// Invoke TMA reduce on dQ'
|
||||
using Vec = uint_bit_t<sizeof_bits_v<ElementAccumulator> * 2>;
|
||||
auto tDQsDQ_v = recast<Vec>(tDQsDQ);
|
||||
auto acc_DQ_v = recast<Vec>(acc_DQ);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_DQ_v); i++) {
|
||||
tDQsDQ_v(i) = acc_DQ_v(i);
|
||||
}
|
||||
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
|
||||
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
|
||||
++smem_pipe_write_reducer;
|
||||
} else {
|
||||
|
||||
warpgroup_fence_operand(acc_DK);
|
||||
warpgroup_arrive();
|
||||
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
warpgroup_wait<1>();
|
||||
warpgroup_fence_operand(acc_DK);
|
||||
|
||||
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
|
||||
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
|
||||
++smem_pipe_write_reducer;
|
||||
}
|
||||
|
||||
--inner_tile_count;
|
||||
|
||||
pipeline_inner.consumer_release(smem_pipe_release_inner);
|
||||
++smem_pipe_release_inner;
|
||||
pipeline_inner.consumer_release(smem_pipe_release_inner);
|
||||
++smem_pipe_release_inner;
|
||||
|
||||
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
|
||||
k_index += 1;
|
||||
}
|
||||
|
||||
pipeline_outer.consumer_release(smem_pipe_read_k);
|
||||
pipeline_outer.consumer_release(smem_pipe_read_outer);
|
||||
pipeline_reducer.producer_tail(smem_pipe_write_reducer);
|
||||
++smem_pipe_read_outer;
|
||||
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_DK);
|
||||
warpgroup_fence_operand(acc_DV);
|
||||
|
||||
return make_tuple(acc_DK, acc_DV);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
140
examples/88_hopper_fmha/collective/fmha_collective_load.hpp
Normal file
140
examples/88_hopper_fmha/collective/fmha_collective_load.hpp
Normal file
@ -0,0 +1,140 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
enum class LoadKind {
|
||||
kQ, kK, kV,
|
||||
kBwdN, kBwdM, kBwdScalar
|
||||
};
|
||||
|
||||
template<
|
||||
LoadKind kKind,
|
||||
class Pipeline,
|
||||
class Element,
|
||||
class SmemLayout,
|
||||
class TMA
|
||||
>
|
||||
struct CollectiveLoadTma {
|
||||
|
||||
using Params = TMA;
|
||||
using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayout>>;
|
||||
using PipelineState = typename cutlass::PipelineState<Pipeline::Stages>;
|
||||
|
||||
Params const& params;
|
||||
Pipeline& pipeline;
|
||||
SharedStorage& storage;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage)
|
||||
: params(params), pipeline(pipeline), storage(storage) {}
|
||||
|
||||
template<class ProblemSize, class TileShape, class BlockCoord>
|
||||
CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape,
|
||||
BlockCoord const& blk_coord, int loop_count
|
||||
) {
|
||||
using X = Underscore;
|
||||
if constexpr (kKind == LoadKind::kK) {
|
||||
Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||
Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||
return gK;
|
||||
} else if constexpr (kKind == LoadKind::kQ) {
|
||||
Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||
Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||
return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout());
|
||||
} else if constexpr (kKind == LoadKind::kV) {
|
||||
Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size)));
|
||||
Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord));
|
||||
return gV;
|
||||
} else if constexpr (kKind == LoadKind::kBwdN) {
|
||||
Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||
return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout());
|
||||
} else if constexpr (kKind == LoadKind::kBwdM) {
|
||||
Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||
return g;
|
||||
} else if constexpr (kKind == LoadKind::kBwdScalar) {
|
||||
Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size));
|
||||
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, X>{});
|
||||
Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord));
|
||||
return g;
|
||||
}
|
||||
}
|
||||
|
||||
template<class ClusterRank, class ProblemSize, class TileShape, class BlockCoord>
|
||||
CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster,
|
||||
ProblemSize const& problem_size, TileShape const& tile_shape,
|
||||
BlockCoord const& block_coord, int loop_count
|
||||
) {
|
||||
Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count);
|
||||
Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{});
|
||||
|
||||
auto block_tma = params.get_slice(block_rank_in_cluster);
|
||||
Tensor ts = block_tma.partition_D(s);
|
||||
Tensor tg = block_tma.partition_S(g);
|
||||
|
||||
return make_tuple(tg, ts);
|
||||
}
|
||||
|
||||
template<bool kAdvanceIterator=true, bool kAdvancePipe=true, bool kAcquireBarrier=true, class TileIterator, class State>
|
||||
CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state,
|
||||
PipelineState& smem_pipe_write,
|
||||
int lane_predicate, int& tile_count, uint16_t mcast_mask = 0
|
||||
) {
|
||||
if ((lane_predicate == 1) && (tile_count > 0)) {
|
||||
if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write);
|
||||
using BarrierType = typename Pipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
if constexpr (kKind == LoadKind::kBwdScalar) {
|
||||
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index()));
|
||||
} else {
|
||||
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index()));
|
||||
}
|
||||
if constexpr (kAdvancePipe) ++smem_pipe_write;
|
||||
if constexpr (kAdvanceIterator) ++tile_iter;
|
||||
}
|
||||
--tile_count;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
305
examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp
Normal file
305
examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp
Normal file
@ -0,0 +1,305 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "../collective/fmha_common.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template<
|
||||
class ElementAccumulator,
|
||||
class Fusion,
|
||||
class Params
|
||||
>
|
||||
struct CollectiveSoftmax {
|
||||
Params const& params;
|
||||
CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {}
|
||||
|
||||
using SumType = float;
|
||||
using MaxType = ElementAccumulator;
|
||||
|
||||
template<class AccPV, class TiledMmaPV>
|
||||
CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) {
|
||||
Tensor s_max = make_fragment_like<MaxType>(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout())));
|
||||
Tensor a_sum = make_fragment_like<SumType>(s_max);
|
||||
return make_tuple(s_max, a_sum);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE float overload_exp2(float f) {
|
||||
return ::exp2f(f);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) {
|
||||
auto a = f.raw();
|
||||
decltype(a) d;
|
||||
asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a));
|
||||
return cutlass::half_t::bitcast(d);
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_DEVICE float overload_max(float a, float b) {
|
||||
return ::max(a, b);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) {
|
||||
return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) {
|
||||
return f.to_half();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE float overload_to_native(float f) {
|
||||
return f;
|
||||
}
|
||||
|
||||
template<class AccQK, class TiledMmaQK, class CountQK, class State, class ProblemShape>
|
||||
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) {
|
||||
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
|
||||
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
|
||||
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
|
||||
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
|
||||
|
||||
auto& s_max = get<0>(state);
|
||||
auto& a_sum = get<1>(state);
|
||||
|
||||
// Linear reduction is faster for the first iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
s_max(i) = acc_qk_mn(i, 0);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 1; j < size<1>(acc_qk_mn); j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
|
||||
}
|
||||
}
|
||||
|
||||
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
|
||||
}
|
||||
}
|
||||
});
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
|
||||
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
|
||||
MaxType scale_max = scale * local_max;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
|
||||
}
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
|
||||
}
|
||||
}
|
||||
|
||||
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
|
||||
CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
|
||||
|
||||
if constexpr (kUseFusion) {
|
||||
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
|
||||
}
|
||||
|
||||
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
|
||||
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||
|
||||
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
|
||||
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
|
||||
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
|
||||
|
||||
auto& s_max = get<0>(state);
|
||||
auto& a_sum = get<1>(state);
|
||||
|
||||
Tensor s_max_prev = make_fragment_like(s_max);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
s_max_prev(i) = s_max(i);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
// Linear reduction is faster here, as well
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
|
||||
}
|
||||
}
|
||||
// reduce max
|
||||
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride<r>(reduction_target_qk) * j));
|
||||
}
|
||||
}
|
||||
});
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
|
||||
float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i);
|
||||
float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2);
|
||||
a_sum(i) *= scale;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
|
||||
acc_pv_mn(i, j) *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class AccQK_MN, class State>
|
||||
CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) {
|
||||
|
||||
auto& s_max = get<0>(state);
|
||||
auto& a_sum = get<1>(state);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<0>(acc_qk_mn); j++) {
|
||||
float local_max = s_max(j) == -INFINITY ? 0.f : s_max(j);
|
||||
float scale_max = params.scale_softmax_log2 * local_max;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < size<1>(acc_qk_mn); k++) {
|
||||
acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max);
|
||||
a_sum(j) += acc_qk_mn(j, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
|
||||
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
|
||||
|
||||
if constexpr (kUseFusion) {
|
||||
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
|
||||
}
|
||||
|
||||
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
|
||||
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||
|
||||
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
|
||||
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
|
||||
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
|
||||
|
||||
auto& s_max = get<0>(state);
|
||||
auto& a_sum = get<1>(state);
|
||||
|
||||
Tensor s_max_prev = make_fragment_like(s_max);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||
s_max_prev(i) = s_max(i);
|
||||
|
||||
// Linear reduction is faster here, as well
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
|
||||
}
|
||||
// reduce max
|
||||
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
|
||||
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
|
||||
}
|
||||
});
|
||||
|
||||
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
|
||||
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
|
||||
MaxType scale_max = scale * local_max;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
|
||||
}
|
||||
|
||||
MaxType s_max_cur = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
|
||||
SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale);
|
||||
a_sum(i) *= scale_pv;
|
||||
|
||||
using ElementPV = typename AccPV::value_type;
|
||||
ElementPV scale_pv_ele = static_cast<ElementPV>(scale_pv);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
|
||||
acc_pv_mn(i, j) *= scale_pv_ele;
|
||||
}
|
||||
a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<class State, class AccPV, class TiledMmaPV>
|
||||
CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) {
|
||||
auto& s_max = get<0>(state);
|
||||
auto& a_sum = get<1>(state);
|
||||
|
||||
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||
|
||||
auto reduction_target = reduction_target_n(tiled_mma_pv);
|
||||
constexpr int red_rank = decltype(rank(reduction_target))::value;
|
||||
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 1; j < shape<r>(reduction_target); j *= 2) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
|
||||
a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride<r>(reduction_target) * j);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||
|
||||
Tensor lse = make_fragment_like(a_sum);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(acc_mn); i++) {
|
||||
float sum = a_sum(i);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum);
|
||||
lse(i) = (sum == 0.f || sum != sum) ? INFINITY : s_max(i) * params.scale_softmax + __logf(sum);
|
||||
float scale = params.rp_dropout * inv_sum;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(acc_mn); j++) {
|
||||
acc_mn(i, j) *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
return lse;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
526
examples/88_hopper_fmha/collective/fmha_collective_tma.hpp
Normal file
526
examples/88_hopper_fmha/collective/fmha_collective_tma.hpp
Normal file
@ -0,0 +1,526 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "../collective/fmha_common.hpp"
|
||||
#include "../collective/fmha_collective_load.hpp"
|
||||
#include "../collective/fmha_collective_softmax.hpp"
|
||||
#include "../kernel/fmha_options.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
using cutlass::fmha::kernel::Tag;
|
||||
using cutlass::fmha::kernel::find_option_t;
|
||||
|
||||
template<
|
||||
typename Element_,
|
||||
typename ElementAccumulator_,
|
||||
typename TileShape_, // BlockQO, BlockKV, BlockHead
|
||||
class Fusion,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaMainloopTma {
|
||||
|
||||
using Element = Element_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using TileShape = TileShape_;
|
||||
|
||||
// Options
|
||||
using kClusterM = find_option_t<Tag::kClusterM, Int<1>, Options...>;
|
||||
static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<4>, Options...>::value;
|
||||
static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<1>, Options...>::value;
|
||||
|
||||
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
|
||||
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
|
||||
using ClusterShape = Shape<kClusterM, _1, _1>;
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int Alignment = 16 / sizeof(Element);
|
||||
|
||||
using TileShapeQK = TileShape;
|
||||
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
|
||||
|
||||
using LayoutQKV = cute::tuple<int, _1, cute::tuple<int, int>>;
|
||||
using LayoutQ = LayoutQKV;
|
||||
using LayoutK = LayoutQKV;
|
||||
using LayoutV = LayoutQKV;
|
||||
|
||||
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Element, LayoutQ, Alignment,
|
||||
Element, LayoutK, Alignment,
|
||||
ElementAccumulator,
|
||||
TileShapeQK, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
// the stride for A does not matter since we do not load from smem at all
|
||||
Element, LayoutK, Alignment,
|
||||
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
|
||||
ElementAccumulator,
|
||||
TileShapePV, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
|
||||
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
|
||||
|
||||
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
|
||||
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
|
||||
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
|
||||
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
|
||||
|
||||
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||
|
||||
using TileShapeOut = TileShapePV;
|
||||
using TiledMmaOut = TiledMmaPV;
|
||||
using ElementOut = ElementAccumulator;
|
||||
|
||||
struct SharedStorage {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
union {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
|
||||
};
|
||||
};
|
||||
|
||||
struct Arguments {
|
||||
const Element* ptr_Q;
|
||||
LayoutQ dQ;
|
||||
const Element* ptr_K;
|
||||
LayoutK dK;
|
||||
const Element* ptr_V;
|
||||
LayoutV dV;
|
||||
};
|
||||
|
||||
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||
|
||||
struct Params {
|
||||
TMA_Q tma_load_q;
|
||||
TMA_K tma_load_k;
|
||||
TMA_V tma_load_v;
|
||||
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
float rp_dropout;
|
||||
};
|
||||
|
||||
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
|
||||
cutlass::fmha::collective::LoadKind::kQ,
|
||||
MainloopPipelineQ,
|
||||
Element,
|
||||
SmemLayoutQ,
|
||||
TMA_Q
|
||||
>;
|
||||
|
||||
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
|
||||
cutlass::fmha::collective::LoadKind::kK,
|
||||
MainloopPipeline,
|
||||
Element,
|
||||
SmemLayoutK,
|
||||
TMA_K
|
||||
>;
|
||||
|
||||
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
|
||||
cutlass::fmha::collective::LoadKind::kV,
|
||||
MainloopPipeline,
|
||||
Element,
|
||||
SmemLayoutV,
|
||||
TMA_V
|
||||
>;
|
||||
|
||||
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
|
||||
|
||||
static const int MaxThreadsPerBlock = size(typename CollectiveMmaQK::TiledMma{});
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
|
||||
return true
|
||||
&& (get<4>(problem_size) <= get<2>(TileShape{}))
|
||||
&& ((get<4>(problem_size) % Alignment) == 0)
|
||||
&& ((get<2>(problem_size) % Alignment) == 0)
|
||||
;
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
|
||||
|
||||
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
|
||||
typename CollectiveMmaQK::Arguments {
|
||||
args.ptr_Q, args.dQ,
|
||||
args.ptr_K, args.dK,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
|
||||
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
|
||||
typename CollectiveMmaPV::Arguments {
|
||||
args.ptr_K, args.dK, // never used, dummy
|
||||
args.ptr_V, select<1,0,2>(args.dV),
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
return Params{
|
||||
params_qk.tma_load_a,
|
||||
params_qk.tma_load_b,
|
||||
params_pv.tma_load_b,
|
||||
1.0f / (float) std::sqrt(get<4>(problem_size)),
|
||||
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
|
||||
1.0f
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
compute(
|
||||
int block_rank_in_cluster,
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write,
|
||||
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q,
|
||||
SharedStorage& storage)
|
||||
{
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
[[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
|
||||
|
||||
|
||||
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
|
||||
|
||||
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
|
||||
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1);
|
||||
|
||||
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
|
||||
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
|
||||
|
||||
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
|
||||
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
|
||||
|
||||
// Set predicate for the lowest lane_id in the warp
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue TmaLoads (Prologue fetches)
|
||||
if (warp_idx == 0) {
|
||||
auto q_tile_iter = cute::make_coord_iterator(1);
|
||||
int q_tile_count = 1;
|
||||
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||
}
|
||||
|
||||
// Loop over K elems
|
||||
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
|
||||
|
||||
int k_tile_count_tma = 2 * fusion_tile_count;
|
||||
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
if (warp_idx == 0 && lane_predicate == 1) {
|
||||
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < StageCount; i++) {
|
||||
if (i % 2 == 0) {
|
||||
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||
} else {
|
||||
load_v.template step<true>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TiledMmaQK tiled_mma_qk;
|
||||
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
|
||||
// Mainloop setup QK
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
|
||||
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE)
|
||||
|
||||
// Prepare: MMA PV
|
||||
TiledMmaPV tiled_mma_pv;
|
||||
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
|
||||
|
||||
// Mainloop setup PV
|
||||
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
|
||||
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE)
|
||||
|
||||
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||
|
||||
pipeline_q.consumer_wait(smem_pipe_read_q);
|
||||
|
||||
// mapping into QK accumulator
|
||||
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
|
||||
Tensor tPcP = thr_mma_qk.partition_C(cP);
|
||||
int m_block = get<0>(blk_coord);
|
||||
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
|
||||
|
||||
// Allocate PV acc
|
||||
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
|
||||
|
||||
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulator, Fusion, decltype(params)> softmax{params};
|
||||
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
|
||||
|
||||
if (true)
|
||||
{
|
||||
--k_tile_count;
|
||||
// Allocate QK acc
|
||||
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA QK
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
|
||||
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
|
||||
|
||||
Tensor acc_qk_fixed = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})));
|
||||
|
||||
Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
acc_qk_input(i) = static_cast<Element>(acc_qk(i));
|
||||
}
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA PV
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
warpgroup_fence_operand(acc_qk_fixed);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
//
|
||||
// Advance the pipe
|
||||
//
|
||||
|
||||
// Advance consumer pipeline
|
||||
++smem_pipe_read;
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// Allocate QK acc
|
||||
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA QK
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
|
||||
if (warp_idx == 0) {
|
||||
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||
}
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
|
||||
softmax.template step_interleave_begin<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
++smem_pipe_release;
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA PV
|
||||
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
|
||||
|
||||
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
|
||||
|
||||
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<2>(tOrV); i++) {
|
||||
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
|
||||
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
|
||||
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
|
||||
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(acc_qk_element_mk); j++) {
|
||||
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
|
||||
}
|
||||
warpgroup_arrive();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(tOrV); j++) {
|
||||
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
|
||||
}
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
++smem_pipe_read;
|
||||
|
||||
if (warp_idx == 0) {
|
||||
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||
}
|
||||
|
||||
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||
}
|
||||
|
||||
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// Allocate QK acc
|
||||
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA QK
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
|
||||
if (warp_idx == 0) {
|
||||
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||
}
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
|
||||
softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
++smem_pipe_release;
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA PV
|
||||
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
|
||||
|
||||
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
|
||||
|
||||
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<2>(tOrV); i++) {
|
||||
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
|
||||
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
|
||||
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
|
||||
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size(acc_qk_element_mk); j++) {
|
||||
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
|
||||
}
|
||||
warpgroup_arrive();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < size<1>(tOrV); j++) {
|
||||
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
|
||||
}
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
++smem_pipe_read;
|
||||
|
||||
if (warp_idx == 0) {
|
||||
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||
}
|
||||
|
||||
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||
}
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
|
||||
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
|
||||
|
||||
return make_tuple(acc_pv, lse);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
|
||||
@ -0,0 +1,560 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "../collective/fmha_common.hpp"
|
||||
#include "../collective/fmha_collective_load.hpp"
|
||||
#include "../collective/fmha_collective_softmax.hpp"
|
||||
#include "../kernel/fmha_options.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
using cutlass::fmha::kernel::Tag;
|
||||
using cutlass::fmha::kernel::find_option_t;
|
||||
|
||||
template<
|
||||
class Element_,
|
||||
class ElementAccumulatorQK_,
|
||||
class ElementAccumulatorPV_,
|
||||
class TileShape_, // SeqQ, SeqKV, Head
|
||||
class LayoutQ_, class LayoutK_, class LayoutV_, // SeqX, Head, (Batches)
|
||||
class Fusion,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaMainloopTmaWarpSpecialized {
|
||||
|
||||
using Element = Element_;
|
||||
using ElementAccumulatorQK = ElementAccumulatorQK_;
|
||||
using ElementAccumulatorPV = ElementAccumulatorPV_;
|
||||
using TileShape = TileShape_;
|
||||
|
||||
using LayoutQ = LayoutQ_;
|
||||
using LayoutK = LayoutK_;
|
||||
using LayoutV = LayoutV_;
|
||||
|
||||
// Options
|
||||
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
|
||||
static constexpr bool kIsMainloopLocked = find_option_t<Tag::kIsMainloopLocked, false_type, Options...>::value;
|
||||
|
||||
static constexpr int NumLoadWarpGroups = 1;
|
||||
static constexpr int NumMmaWarpGroups = find_option_t<Tag::kNumMmaWarpGroups, Int<2>, Options...>::value;
|
||||
static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<5>, Options...>::value;
|
||||
static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<NumMmaWarpGroups>, Options...>::value;
|
||||
|
||||
static const int kOuterLoads = 1;
|
||||
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
|
||||
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
static_assert(StagesQ::value >= NumMmaWarpGroups);
|
||||
static_assert(Stages::value >= 2);
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int Alignment = 16 / sizeof(Element);
|
||||
|
||||
using TileShapeQK = Shape<
|
||||
decltype(tuple_element_t<0, TileShape>{} / Int<NumMmaWarpGroups>{}),
|
||||
tuple_element_t<1, TileShape>,
|
||||
tuple_element_t<2, TileShape>>;
|
||||
|
||||
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
|
||||
|
||||
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Element, LayoutQ, Alignment,
|
||||
Element, LayoutK, Alignment,
|
||||
ElementAccumulatorQK,
|
||||
TileShapeQK, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
// the stride for A does not matter since we do not load from smem at all
|
||||
Element, LayoutK, Alignment,
|
||||
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
|
||||
ElementAccumulatorPV,
|
||||
TileShapePV, ClusterShape, Stages,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||
|
||||
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
|
||||
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
|
||||
|
||||
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
|
||||
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
|
||||
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
|
||||
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
|
||||
|
||||
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||
|
||||
static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
|
||||
static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element);
|
||||
|
||||
using TileShapeOut = TileShapePV;
|
||||
using TiledMmaOut = TiledMmaPV;
|
||||
using ElementOut = ElementAccumulatorPV;
|
||||
|
||||
struct SharedStorage {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
union {
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
|
||||
};
|
||||
};
|
||||
|
||||
struct Arguments {
|
||||
const Element* ptr_Q;
|
||||
LayoutQ dQ;
|
||||
const Element* ptr_K;
|
||||
LayoutK dK;
|
||||
const Element* ptr_V;
|
||||
LayoutV dV;
|
||||
};
|
||||
|
||||
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||
|
||||
struct Params {
|
||||
TMA_Q tma_load_q;
|
||||
TMA_K tma_load_k;
|
||||
TMA_V tma_load_v;
|
||||
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
float rp_dropout;
|
||||
};
|
||||
|
||||
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
|
||||
cutlass::fmha::collective::LoadKind::kQ,
|
||||
MainloopPipelineQ,
|
||||
Element,
|
||||
SmemLayoutQ,
|
||||
TMA_Q
|
||||
>;
|
||||
|
||||
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
|
||||
cutlass::fmha::collective::LoadKind::kK,
|
||||
MainloopPipeline,
|
||||
Element,
|
||||
SmemLayoutK,
|
||||
TMA_K
|
||||
>;
|
||||
|
||||
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
|
||||
cutlass::fmha::collective::LoadKind::kV,
|
||||
MainloopPipeline,
|
||||
Element,
|
||||
SmemLayoutV,
|
||||
TMA_V
|
||||
>;
|
||||
|
||||
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
|
||||
return true
|
||||
&& (get<4>(problem_size) <= get<2>(TileShape{}))
|
||||
&& ((get<4>(problem_size) % Alignment) == 0)
|
||||
&& ((get<2>(problem_size) % Alignment) == 0)
|
||||
;
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
|
||||
|
||||
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
|
||||
typename CollectiveMmaQK::Arguments {
|
||||
args.ptr_Q, args.dQ,
|
||||
args.ptr_K, args.dK,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
|
||||
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
|
||||
typename CollectiveMmaPV::Arguments {
|
||||
args.ptr_K, args.dK, // never used, dummy
|
||||
args.ptr_V, select<1,0,2>(args.dV),
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
return Params{
|
||||
params_qk.tma_load_a,
|
||||
params_qk.tma_load_b,
|
||||
params_pv.tma_load_b,
|
||||
1.0f / (float) std::sqrt(get<4>(problem_size)),
|
||||
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
|
||||
1.0f
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<bool kLoadQ, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||
CUTLASS_DEVICE void
|
||||
load_kv_maybe_q(
|
||||
int block_rank_in_cluster,
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipeline& pipeline, PipelineState& smem_pipe_write,
|
||||
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
|
||||
SharedStorage& storage,
|
||||
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||
{
|
||||
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
if (lane_predicate == 1) {
|
||||
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
|
||||
[[maybe_unused]] int q_tile_count = NumMmaWarpGroups;
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
|
||||
int k_tile_count = 2 * fusion_tile_count;
|
||||
|
||||
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
|
||||
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
|
||||
|
||||
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
|
||||
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
|
||||
|
||||
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
|
||||
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
|
||||
|
||||
if constexpr (kLoadQ) {
|
||||
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||
}
|
||||
|
||||
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||
|
||||
if constexpr (kLoadQ) {
|
||||
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||
}
|
||||
|
||||
if constexpr (! kLoadQ) {
|
||||
if (do_barrier) {
|
||||
load_warp_barrier.arrive();
|
||||
load_warp_barrier.wait(/*phase=*/ 0);
|
||||
do_barrier = false;
|
||||
}
|
||||
}
|
||||
|
||||
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||
|
||||
if constexpr (kLoadQ) {
|
||||
while (q_tile_count > 0) {
|
||||
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0) {
|
||||
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||
CUTLASS_DEVICE void
|
||||
load_maybe_q(
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
|
||||
SharedStorage& storage,
|
||||
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
|
||||
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
|
||||
|
||||
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) {
|
||||
int count = 1;
|
||||
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count);
|
||||
if (q_tile_count == 0 && do_barrier) {
|
||||
load_warp_barrier.arrive();
|
||||
load_warp_barrier.wait(/*phase=*/ 0);
|
||||
do_barrier = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
|
||||
CUTLASS_DEVICE void
|
||||
reduce(
|
||||
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
|
||||
SharedStorage& storage)
|
||||
{ /* no-op */ }
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
|
||||
CUTLASS_DEVICE auto
|
||||
compute(
|
||||
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
|
||||
Params const& params, ProblemShape const& problem_size,
|
||||
MainloopPipeline& pipeline, PipelineState& smem_pipe_read,
|
||||
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipelineReducer&, PipelineStateReducer&,
|
||||
SharedStorage& storage,
|
||||
MathWgOrderBarrier& math_wg_order_barrier)
|
||||
{
|
||||
int thread_idx = int(threadIdx.x);
|
||||
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
|
||||
|
||||
TiledMmaQK tiled_mma_qk;
|
||||
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
|
||||
// Mainloop setup QK
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
|
||||
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE)
|
||||
|
||||
// Prepare: MMA PV
|
||||
TiledMmaPV tiled_mma_pv;
|
||||
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
|
||||
|
||||
// Mainloop setup PV
|
||||
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
|
||||
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE)
|
||||
|
||||
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||
|
||||
pipeline_q.consumer_wait(smem_pipe_read_q);
|
||||
|
||||
// mapping into QK accumulator
|
||||
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
|
||||
Tensor tPcP = thr_mma_qk.partition_C(cP);
|
||||
int m_block = get<0>(wg_coord);
|
||||
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
|
||||
|
||||
// Allocate PV acc
|
||||
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
|
||||
|
||||
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulatorQK, Fusion, decltype(params)> softmax{params};
|
||||
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
|
||||
|
||||
if (true)
|
||||
{
|
||||
--k_tile_count;
|
||||
// Allocate QK acc
|
||||
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
// MMA QK
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||
warpgroup_commit_batch();
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
++smem_pipe_read;
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
|
||||
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
|
||||
|
||||
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA PV
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
warpgroup_fence_operand(acc_qk_fixed);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
// Advance consumer pipeline
|
||||
++smem_pipe_read;
|
||||
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0)
|
||||
{
|
||||
--k_tile_count;
|
||||
|
||||
// Allocate QK acc
|
||||
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA QK
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
|
||||
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
|
||||
softmax.template step<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
|
||||
|
||||
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read, tok);
|
||||
|
||||
// MMA PV
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
warpgroup_fence_operand(acc_qk_fixed);
|
||||
warpgroup_arrive();
|
||||
|
||||
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
++smem_pipe_read;
|
||||
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||
}
|
||||
|
||||
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0)
|
||||
{
|
||||
--k_tile_count;
|
||||
|
||||
// Allocate QK acc
|
||||
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read);
|
||||
|
||||
// MMA QK
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_arrive();
|
||||
|
||||
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_qk);
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
|
||||
//if constexpr (kIsPersistent)
|
||||
// if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q);
|
||||
|
||||
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
|
||||
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
|
||||
|
||||
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
|
||||
|
||||
pipeline.consumer_wait(smem_pipe_read, tok);
|
||||
|
||||
// MMA PV
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
warpgroup_fence_operand(acc_qk_fixed);
|
||||
warpgroup_arrive();
|
||||
|
||||
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||
warpgroup_commit_batch();
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
++smem_pipe_read;
|
||||
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||
}
|
||||
|
||||
if (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q);
|
||||
|
||||
// Wait for the pipeline MMAs to drain
|
||||
warpgroup_wait<0>();
|
||||
warpgroup_fence_operand(acc_pv);
|
||||
|
||||
if (kIsPersistent) pipeline.consumer_release(smem_pipe_release);
|
||||
++smem_pipe_release;
|
||||
|
||||
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
|
||||
|
||||
return make_tuple(acc_pv, lse);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
245
examples/88_hopper_fmha/collective/fmha_common.hpp
Normal file
245
examples/88_hopper_fmha/collective/fmha_common.hpp
Normal file
@ -0,0 +1,245 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename Atom, typename TA, typename TB, typename TC>
|
||||
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
|
||||
constexpr int rA = decltype(rank(tA))::value;
|
||||
constexpr int rB = decltype(rank(tB))::value;
|
||||
constexpr int rC = decltype(rank(tC))::value;
|
||||
if constexpr (rA == 2 && rB == 2 && rC == 1) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<1>(tA); k_block++) {
|
||||
cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC);
|
||||
atom.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
} else {
|
||||
static_assert(rA == 3 && rB == 3 && rC == 3);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
|
||||
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
|
||||
atom.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Atom, typename TA, typename TB, typename TC>
|
||||
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
|
||||
atom.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
gemm_reset_zero_acc(atom, tA, tB, tC);
|
||||
}
|
||||
|
||||
template<typename T, typename Fn>
|
||||
CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) {
|
||||
if constexpr (decltype(size(t) % _2{} == _0{})::value) {
|
||||
auto partial = make_tensor<typename T::value_type>(size(t) / _2{});
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(partial); i++) {
|
||||
partial(i) = fn(t(i), t(i + size(partial)));
|
||||
}
|
||||
return reduce(partial, fn);
|
||||
} else {
|
||||
auto result = t(_0{});
|
||||
CUTE_UNROLL
|
||||
for (int i = 1; i < size(t); i++) {
|
||||
result = fn(result, t(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
struct fmha_max {
|
||||
CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); }
|
||||
};
|
||||
|
||||
template<typename Threshold, typename Source, typename Reference>
|
||||
inline auto __device__ constexpr layout_separate(Threshold const& thr,
|
||||
Source const& src, Reference const& ref) {
|
||||
auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
|
||||
if constexpr(decltype(r < thr)::value) {
|
||||
return s;
|
||||
} else {
|
||||
return make_layout(_1{}, _0{});
|
||||
}
|
||||
}));
|
||||
auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
|
||||
if constexpr(decltype(r >= thr)::value) {
|
||||
return s;
|
||||
} else {
|
||||
return make_layout(_1{}, _0{});
|
||||
}
|
||||
}));
|
||||
return make_layout(lt, ge);
|
||||
}
|
||||
|
||||
template<typename TiledMma, typename Acc>
|
||||
inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) {
|
||||
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
|
||||
get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{}));
|
||||
auto V_M = get<0>(separated);
|
||||
auto V_N = get<1>(separated);
|
||||
return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc)));
|
||||
}
|
||||
|
||||
template<typename TiledMma, typename Acc>
|
||||
inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) {
|
||||
return layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
|
||||
get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{}));
|
||||
}
|
||||
|
||||
template<typename TiledMma, typename Acc>
|
||||
inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& acc) {
|
||||
return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout()));
|
||||
}
|
||||
|
||||
template<typename TiledMma>
|
||||
inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) {
|
||||
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
|
||||
make_layout(shape<0>(typename TiledMma::LayoutC_TV{})),
|
||||
stride<0>(typename TiledMma::LayoutC_TV{}));
|
||||
return get<1>(separated);
|
||||
}
|
||||
|
||||
|
||||
template<template<cute::GMMA::Major, cute::GMMA::Major, cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
|
||||
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<tA, tB, sA, sB>> const& tiled_mma) {
|
||||
using Atom = cute::MMA_Atom<Primitive<tA, tB, sA, sB>>;
|
||||
using ElementA = typename Atom::ValTypeA;
|
||||
using ElementB = typename Atom::ValTypeB;
|
||||
using ElementC = typename Atom::ValTypeC;
|
||||
using Shape_MNK = typename Atom::Shape_MNK;
|
||||
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
|
||||
return cute::MMA_Atom<RS>{};
|
||||
}
|
||||
|
||||
template<template<cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
|
||||
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<sA, sB>> const& tiled_mma) {
|
||||
using Atom = cute::MMA_Atom<Primitive<sA, sB>>;
|
||||
using ElementA = typename Atom::ValTypeA;
|
||||
using ElementB = typename Atom::ValTypeB;
|
||||
using ElementC = typename Atom::ValTypeC;
|
||||
using Shape_MNK = typename Atom::Shape_MNK;
|
||||
constexpr auto tA = cute::GMMA::Major::K;
|
||||
constexpr auto tB = cute::GMMA::Major::K;
|
||||
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
|
||||
return cute::MMA_Atom<RS>{};
|
||||
}
|
||||
|
||||
template<class Atom, class... Args>
|
||||
CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA<Atom, Args...> const& tiled_mma) {
|
||||
return cute::TiledMMA<decltype(convert_to_gmma_rs(Atom{})), Args...>{};
|
||||
}
|
||||
|
||||
template<typename CLayout, typename AValueShape>
|
||||
CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) {
|
||||
return make_layout(
|
||||
make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))),
|
||||
make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c))));
|
||||
}
|
||||
|
||||
template<class Layout, class Stages = _1>
|
||||
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
|
||||
return composition(layout, make_tuple(_, _, make_layout(stages)));
|
||||
}
|
||||
|
||||
template<class Element, class Accumulator, class OperandLayout_TV>
|
||||
CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) {
|
||||
Tensor operand = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv)));
|
||||
Tensor operand_as_acc = make_tensor(operand.data(), acc.layout());
|
||||
|
||||
cute::copy(acc, operand_as_acc);
|
||||
|
||||
if constexpr (sizeof(Element) == 1) {
|
||||
|
||||
// 00 11 22 33 00 11 22 33 acc layout
|
||||
// 00 00 11 11 22 22 33 33 operand layout
|
||||
// BB AA AA BB AA BB BB AA conflict-free exchange pattern
|
||||
// 16-bit exchange; so process two at a time potentially
|
||||
int tid = threadIdx.x % 4;
|
||||
auto values_u32 = recast<uint32_t>(operand);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(values_u32); n++) {
|
||||
CUTE_UNROLL
|
||||
for (int k = 0; k < size<2>(values_u32); k++) {
|
||||
CUTE_UNROLL
|
||||
for (int ii = 0; ii < 8; ii += 4) {
|
||||
|
||||
uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k);
|
||||
uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k);
|
||||
|
||||
// step A:
|
||||
// t 1 v 0 -> t 0 v 1
|
||||
// t 2 v 0 -> t 1 v 0
|
||||
// t 0 v 1 -> t 2 v 0
|
||||
// t 3 v 1 -> t 3 v 1
|
||||
|
||||
int v_to_send = tid == 1 || tid == 2 ? 0 : 1;
|
||||
int v_to_recv = v_to_send;
|
||||
int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF;
|
||||
|
||||
uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
|
||||
|
||||
values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4);
|
||||
|
||||
// step B:
|
||||
// t 0 v 0 -> t 0 v 0
|
||||
// t 3 v 0 -> t 1 v 1
|
||||
// t 1 v 1 -> t 2 v 1
|
||||
// t 2 v 1 -> t 3 v 0
|
||||
|
||||
v_to_send = 1 - v_to_send;
|
||||
v_to_recv = 1 - v_to_recv;
|
||||
t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF;
|
||||
|
||||
uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
|
||||
|
||||
values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4);
|
||||
|
||||
values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410);
|
||||
values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return operand;
|
||||
}
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
156
examples/88_hopper_fmha/collective/fmha_epilogue.hpp
Normal file
156
examples/88_hopper_fmha/collective/fmha_epilogue.hpp
Normal file
@ -0,0 +1,156 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "../collective/fmha_common.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template<class Element, class ElementAccumulator, class TileShape_WG>
|
||||
struct FmhaFwdEpilogue {
|
||||
|
||||
static constexpr int Alignment = 16 / sizeof(Element);
|
||||
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
|
||||
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||
cutlass::epilogue::TmaWarpSpecialized,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
struct Arguments {
|
||||
Element* ptr_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> dO;
|
||||
|
||||
ElementAccumulator* ptr_LSE;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ElementAccumulator* ptr_LSE;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
|
||||
|
||||
typename CollectiveEpilogueTMA::Params epilogue_TMA;
|
||||
};
|
||||
|
||||
using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage;
|
||||
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
|
||||
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
|
||||
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
|
||||
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1,
|
||||
make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||
typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO};
|
||||
return Params{
|
||||
args.ptr_LSE, args.dLSE,
|
||||
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace)
|
||||
};
|
||||
}
|
||||
|
||||
template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
|
||||
CUTLASS_DEVICE void operator()(
|
||||
TileShape const& tile_shape, BlkCoord const& blk_coord,
|
||||
ResultTuple const& result, TiledMma const& tiled_mma,
|
||||
ProblemShape const& problem_size, Params const& params,
|
||||
LoadPipeline epi_load_pipeline,
|
||||
TensorStorage& epi_tensor_storage)
|
||||
{
|
||||
using X = Underscore;
|
||||
|
||||
auto acc = get<0>(result);
|
||||
auto lse = get<1>(result);
|
||||
|
||||
auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
|
||||
|
||||
int seqlen_q = get<2>(problem_size);
|
||||
int num_batch = get<0>(problem_size);
|
||||
int num_heads = get<1>(problem_size);
|
||||
// Epilogue for lse
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE),
|
||||
make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)),
|
||||
make_stride(_1{}, _0{}, get<1>(params.dLSE)));
|
||||
Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{});
|
||||
Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord));
|
||||
Tensor tOgLSE = thr_mma.partition_C(gLSE);
|
||||
Tensor cO = make_identity_tensor(take<0,2>(tile_shape));
|
||||
Tensor tOcO = thr_mma.partition_C(cO);
|
||||
if (get<1>(tOcO(_0{})) == 0) {
|
||||
auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout()));
|
||||
auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout()));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tOgLSE_mn); i++) {
|
||||
if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) {
|
||||
tOgLSE_mn(i, _0{}) = lse(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _,
|
||||
make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||
|
||||
CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage);
|
||||
|
||||
using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
||||
epilogue_tma.store(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state,
|
||||
problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)),
|
||||
acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
|
||||
epi_tensor_storage
|
||||
);
|
||||
|
||||
epilogue_tma.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state_next
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
157
examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp
Normal file
157
examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp
Normal file
@ -0,0 +1,157 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "../collective/fmha_epilogue.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template<class Element, class ElementAccumulator, class TileShape_WG>
|
||||
struct FmhaBwdEpilogueKV {
|
||||
|
||||
static constexpr int Alignment = 16 / sizeof(Element);
|
||||
|
||||
struct Arguments {
|
||||
Element* ptr_K;
|
||||
cute::tuple<int, int, int, cute::_1> dK;
|
||||
|
||||
Element* ptr_V;
|
||||
cute::tuple<int, int, int, _1> dV;
|
||||
};
|
||||
|
||||
//using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90Compute<cutlass::first, Element, ElementAccumulator, RoundStyle>,
|
||||
cutlass::epilogue::fusion::Sm90AccFetch
|
||||
>;
|
||||
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||
cutlass::epilogue::TmaWarpSpecialized,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
struct Params {
|
||||
typename CollectiveEpilogueTMA::Params epilogue_K;
|
||||
typename CollectiveEpilogueTMA::Params epilogue_V;
|
||||
};
|
||||
|
||||
|
||||
using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2];
|
||||
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
|
||||
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
|
||||
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
|
||||
auto dK = make_stride(get<2>(args.dK), get<3>(args.dK),
|
||||
make_stride(get<0>(args.dK), get<1>(args.dK)));
|
||||
auto dV = make_stride(get<2>(args.dV), get<3>(args.dV),
|
||||
make_stride(get<0>(args.dV), get<1>(args.dV)));
|
||||
|
||||
auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1,
|
||||
make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||
typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK};
|
||||
typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV};
|
||||
return Params{
|
||||
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr),
|
||||
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr)
|
||||
};
|
||||
}
|
||||
|
||||
template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
|
||||
CUTLASS_DEVICE void operator()(
|
||||
TileShape const& tile_shape, BlkCoord const& blk_coord,
|
||||
ResultTuple const& result, TiledMma const& tiled_mma,
|
||||
ProblemShape const& problem_size, Params const& params,
|
||||
LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage)
|
||||
{
|
||||
auto acc_k = get<0>(result);
|
||||
auto acc_v = get<1>(result);
|
||||
|
||||
auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _,
|
||||
make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||
|
||||
using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]};
|
||||
CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]};
|
||||
|
||||
{
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
||||
epilogue_k.store(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state,
|
||||
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
|
||||
acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
|
||||
epi_tensor_storage[0]
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
{
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
||||
epilogue_v.store(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state,
|
||||
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
|
||||
acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
|
||||
epi_tensor_storage[1]
|
||||
);
|
||||
|
||||
epilogue_k.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state_next
|
||||
);
|
||||
|
||||
epilogue_v.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state_next
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
283
examples/88_hopper_fmha/collective/fmha_fusion.hpp
Normal file
283
examples/88_hopper_fmha/collective/fmha_fusion.hpp
Normal file
@ -0,0 +1,283 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct DefaultFusion {
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return ceil_div(get<3>(problem_size), get<1>(tile_shape));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void before_softmax(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size
|
||||
|
||||
) {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
struct ResidualFusion : DefaultFusion {
|
||||
|
||||
using Base = DefaultFusion;
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void before_softmax(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
// This is useful is seqlen_k % kBlockN != 0 since it masks
|
||||
// the remaining elements out from softmax.
|
||||
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
|
||||
// issues as they are transparently taken care of by TMA and the
|
||||
// epilogue, if it is instantiated with predication support.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if (get<1>(pos) >= get<3>(problem_size)) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CausalFusion : DefaultFusion {
|
||||
|
||||
using Base = DefaultFusion;
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
// See note below on different ways to think about causal attention
|
||||
// Again, we'd add the offset_q into the max_blocks_q calculation
|
||||
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void before_softmax(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
// There are two ways to do causal if N_Q != N_K
|
||||
// (1) is to assume that the Q is at the beginning of the matrix
|
||||
// - this is what we demonstrate here
|
||||
// (2) is that it is at the end of the matrix
|
||||
// - this is usually what we want for inference settings
|
||||
// where we only compute the next row and use cache for the rest
|
||||
// - if you'd like this, you only need to add an offset like so:
|
||||
// get<0>(pos) + offset_q < get<1>(pos)
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if (get<0>(pos) < get<1>(pos)) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template<class Base>
|
||||
struct FusionBwdAdapter {
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size));
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void before_softmax(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
auto index_base = index_qk(_0{});
|
||||
auto index_shape = shape(index_qk);
|
||||
auto index_stride = transform_leaf(stride(index_qk), [](auto elem) {
|
||||
if constexpr (is_scaled_basis<decltype(elem)>::value) {
|
||||
if constexpr(decltype(elem.mode() == _0{})::value) {
|
||||
return ScaledBasis<decltype(elem.value()), 1>(elem.value());
|
||||
} else {
|
||||
return ScaledBasis<decltype(elem.value()), 0>(elem.value());
|
||||
}
|
||||
} else {
|
||||
return elem;
|
||||
}
|
||||
});
|
||||
auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride));
|
||||
Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size);
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
bool is_contributing(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FusionBwdAdapter<CausalFusion> {
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
return get<2>(problem_size) / get<0>(TileShape{});
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void before_softmax(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size
|
||||
|
||||
) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if (get<1>(pos) < get<0>(pos)) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
bool is_contributing(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size
|
||||
) {
|
||||
int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape);
|
||||
int min_k = get<1>(blk_coord) * get<1>(tile_shape);
|
||||
return min_k <= max_q;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
278
examples/88_hopper_fmha/device/device_universal.hpp
Normal file
278
examples/88_hopper_fmha/device/device_universal.hpp
Normal file
@ -0,0 +1,278 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class Kernel_>
|
||||
class Universal {
|
||||
public:
|
||||
using Kernel = Kernel_;
|
||||
|
||||
static int const kThreadCount = Kernel::MaxThreadsPerBlock;
|
||||
|
||||
/// Argument structure: User API
|
||||
using Arguments = typename Kernel::Arguments;
|
||||
/// Argument structure: Kernel API
|
||||
using Params = typename Kernel::Params;
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (Kernel::can_implement(args)) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
return Kernel::get_grid_shape(params);
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_ = Kernel::to_underlying_arguments(args, workspace);
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = get_grid_shape(params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
cutlass::arch::synclog_setup();
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result && Status::kSuccess == launch_result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
299
examples/88_hopper_fmha/device/fmha_device_bwd.hpp
Normal file
299
examples/88_hopper_fmha/device/fmha_device_bwd.hpp
Normal file
@ -0,0 +1,299 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "../device/device_universal.hpp"
|
||||
#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp"
|
||||
#include "../collective/fmha_fusion.hpp"
|
||||
#include "../collective/fmha_epilogue_bwd.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_convert.hpp"
|
||||
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
|
||||
#include "../kernel/fmha_tile_scheduler.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Element, class ElementAccumulator, class TileShape, class Fusion, class... Options>
|
||||
class FmhaBwd {
|
||||
public:
|
||||
/// Argument structure: User API
|
||||
struct Arguments {
|
||||
cute::tuple<int, int, int, int, int> problem_size;
|
||||
|
||||
const Element* ptr_Q;
|
||||
cute::tuple<int, int, int, cute::_1> stride_Q;
|
||||
const Element* ptr_K;
|
||||
cute::tuple<int, int, int, cute::_1> stride_K;
|
||||
const Element* ptr_V;
|
||||
cute::tuple<int, int, int, cute::_1> stride_V;
|
||||
|
||||
const Element* ptr_O;
|
||||
cute::tuple<int, int, int, cute::_1> stride_O;
|
||||
const ElementAccumulator* ptr_LSE;
|
||||
cute::tuple<int, int, _1> stride_LSE;
|
||||
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, int, int, cute::_1> stride_dO;
|
||||
|
||||
Element* ptr_dQ;
|
||||
cute::tuple<int, int, int, cute::_1> stride_dQ;
|
||||
Element* ptr_dK;
|
||||
cute::tuple<int, int, int, cute::_1> stride_dK;
|
||||
Element* ptr_dV;
|
||||
cute::tuple<int, int, int, cute::_1> stride_dV;
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
using OperationSumOdO = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>>;
|
||||
using OperationConvert = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>>;
|
||||
|
||||
using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized<
|
||||
Element, ElementAccumulator, TileShape,
|
||||
cutlass::fmha::collective::FusionBwdAdapter<Fusion>, Options...>;
|
||||
|
||||
using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV<Element, ElementAccumulator, typename Mainloop::TileShapePV>;
|
||||
|
||||
using Operation = cutlass::device::Universal<
|
||||
cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<
|
||||
Mainloop,
|
||||
Epilogue,
|
||||
cutlass::fmha::kernel::TileSchedulerBwdAdapter<cutlass::fmha::kernel::IndividualTileScheduler>, Options...>>;
|
||||
|
||||
struct Params {
|
||||
OperationSumOdO op_sum_OdO;
|
||||
Operation op;
|
||||
OperationConvert op_convert;
|
||||
ElementAccumulator* dQ_acc;
|
||||
size_t dQ_acc_size;
|
||||
};
|
||||
|
||||
private:
|
||||
Params params_;
|
||||
|
||||
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) {
|
||||
auto [B, H, Q, K, D] = args.problem_size;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
auto stride_sum_OdO = make_stride(H*Q, Q, _1{});
|
||||
return typename OperationSumOdO::Arguments {
|
||||
args.problem_size,
|
||||
args.ptr_O, args.stride_O,
|
||||
args.ptr_dO, args.stride_dO,
|
||||
dest, stride_sum_OdO
|
||||
};
|
||||
}
|
||||
|
||||
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
|
||||
auto [B, H, Q, K, D] = args.problem_size;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
auto stride_src_dQ = make_stride(B == 1 ? 0 : (H*Q*D), Q*D, D, _1{});
|
||||
return typename OperationConvert::Arguments {
|
||||
args.problem_size,
|
||||
src, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
args.ptr_dQ, args.stride_dQ,
|
||||
nullptr, args.stride_dK,
|
||||
nullptr, args.stride_dV
|
||||
};
|
||||
}
|
||||
|
||||
static typename Operation::Arguments to_bwd_arguments(
|
||||
Arguments const& args,
|
||||
ElementAccumulator* sum_OdO = nullptr, cute::tuple<int, int, _1> const& stride_sum_OdO = {},
|
||||
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, int, int, _1> const& stride_dQ = {}
|
||||
) {
|
||||
return typename Operation::Arguments{
|
||||
args.problem_size,
|
||||
{ args.ptr_Q, args.stride_Q,
|
||||
args.ptr_K, args.stride_K,
|
||||
args.ptr_V, args.stride_V,
|
||||
args.ptr_dO, args.stride_dO,
|
||||
args.ptr_LSE, args.stride_LSE,
|
||||
sum_OdO, stride_sum_OdO,
|
||||
dQ_acc, stride_dQ },
|
||||
{ args.ptr_dK, args.stride_dK,
|
||||
args.ptr_dV, args.stride_dV },
|
||||
args.hw_info
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
Status status = Status::kSuccess;
|
||||
|
||||
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = OperationConvert::can_implement(to_convert_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = Operation::can_implement(to_bwd_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
auto [B, H, Q, K, D] = args.problem_size;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
size_t workspace_bytes = 0;
|
||||
// OdO vector
|
||||
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
|
||||
// FP32 versions of outputs that are churned (start off with Q only)
|
||||
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Initializes state from arguments.
|
||||
Status
|
||||
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
|
||||
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [B, H, Q, K, D] = args.problem_size;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
|
||||
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
|
||||
params_.dQ_acc = dQ_acc;
|
||||
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
|
||||
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO);
|
||||
auto args_convert = to_convert_arguments(args, dQ_acc);
|
||||
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
|
||||
params_.op_convert.initialize(args_convert, nullptr, stream);
|
||||
auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ);
|
||||
params_.op.initialize(args_bwd, nullptr, stream);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Initializes state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [B, H, Q, K, D] = args.problem_size;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
char* workspace_chr = reinterpret_cast<char*>(workspace);
|
||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
|
||||
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
return initialize_split(args, dQ_acc, sum_OdO, stream);
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
|
||||
|
||||
Status result = Status::kSuccess;
|
||||
result = params.op_sum_OdO.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
|
||||
if (cuda_result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
result = params.op.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = params.op_convert.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
158
examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp
Normal file
158
examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp
Normal file
@ -0,0 +1,158 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../collective/fmha_collective_tma.hpp"
|
||||
#include "../collective/fmha_collective_tma_warpspecialized.hpp"
|
||||
#include "../collective/fmha_epilogue.hpp"
|
||||
#include "../kernel/fmha_kernel_tma.hpp"
|
||||
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
|
||||
#include "../kernel/fmha_options.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
template<
|
||||
class Element_,
|
||||
class ElementAccumulatorQK_,
|
||||
class ElementAccumulatorPV_,
|
||||
class TileShape_, // BlockQO, BlockKV, BlockHead
|
||||
class LayoutQ_,
|
||||
class LayoutK_,
|
||||
class LayoutV_,
|
||||
class Fusion,
|
||||
class DispatchPolicy,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaBuilder;
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class ElementAccumulator,
|
||||
class TileShape, // BlockQO, BlockKV, BlockHead
|
||||
class Fusion,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaBuilder<
|
||||
Element,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
cute::tuple<int, _1, cute::tuple<int, int>>,
|
||||
cute::tuple<int, _1, cute::tuple<int, int>>,
|
||||
cute::tuple<int, _1, cute::tuple<int, int>>,
|
||||
Fusion,
|
||||
cutlass::gemm::KernelTma,
|
||||
Options...
|
||||
> {
|
||||
|
||||
using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma<Element, ElementAccumulator, TileShape, Fusion, Options...>;
|
||||
|
||||
using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
|
||||
Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>;
|
||||
|
||||
using Kernel = cutlass::fmha::kernel::FmhaKernelTma<CollectiveMainloop, CollectiveEpilogue, Options...>;
|
||||
};
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class ElementAccumulatorQK,
|
||||
class ElementAccumulatorPV,
|
||||
class TileShape, // BlockQO, BlockKV, BlockHead
|
||||
class LayoutQ,
|
||||
class LayoutK,
|
||||
class LayoutV,
|
||||
class Fusion,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaBuilder<
|
||||
Element,
|
||||
ElementAccumulatorQK,
|
||||
ElementAccumulatorPV,
|
||||
TileShape,
|
||||
LayoutQ,
|
||||
LayoutK,
|
||||
LayoutV,
|
||||
Fusion,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Options...
|
||||
> {
|
||||
|
||||
using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized<
|
||||
Element, ElementAccumulatorQK, ElementAccumulatorPV,
|
||||
TileShape, LayoutQ, LayoutK, LayoutV,
|
||||
Fusion, Options...>;
|
||||
|
||||
using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
|
||||
Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>;
|
||||
|
||||
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
|
||||
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
|
||||
|
||||
using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<CollectiveMainloop, CollectiveEpilogue, TileScheduler, Options...>;
|
||||
};
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class ElementAccumulatorQK,
|
||||
class ElementAccumulatorPV,
|
||||
class TileShape, // BlockQO, BlockKV, BlockHead
|
||||
class LayoutQ,
|
||||
class LayoutK,
|
||||
class LayoutV,
|
||||
class Fusion,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaBuilder<
|
||||
Element,
|
||||
ElementAccumulatorQK,
|
||||
ElementAccumulatorPV,
|
||||
TileShape,
|
||||
LayoutQ,
|
||||
LayoutK,
|
||||
LayoutV,
|
||||
Fusion,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong,
|
||||
Options...
|
||||
> {
|
||||
using Kernel = typename FmhaBuilder<
|
||||
Element, ElementAccumulatorQK, ElementAccumulatorPV,
|
||||
TileShape,
|
||||
LayoutQ, LayoutK, LayoutV,
|
||||
Fusion,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Options...,
|
||||
Option<Tag::kIsPersistent, true_type>,
|
||||
Option<Tag::kLoadsQSeparately, true_type>
|
||||
>::Kernel;
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
143
examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp
Normal file
143
examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp
Normal file
@ -0,0 +1,143 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<class Element, class ElementAccumulator>
|
||||
struct FmhaKernelBwdConvert {
|
||||
|
||||
struct Arguments {
|
||||
tuple<int, int, int, int, int> problem_size;
|
||||
|
||||
const ElementAccumulator* ptr_src_dQ;
|
||||
tuple<int, int, int, _1> stride_src_dQ;
|
||||
const ElementAccumulator* ptr_src_dK;
|
||||
tuple<int, int, int, _1> stride_src_dK;
|
||||
const ElementAccumulator* ptr_src_dV;
|
||||
tuple<int, int, int, _1> stride_src_dV;
|
||||
|
||||
Element* ptr_dest_dQ;
|
||||
tuple<int, int, int, _1> stride_dest_dQ;
|
||||
Element* ptr_dest_dK;
|
||||
tuple<int, int, int, _1> stride_dest_dK;
|
||||
Element* ptr_dest_dV;
|
||||
tuple<int, int, int, _1> stride_dest_dV;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
static constexpr int SharedStorageSize = 0;
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
static const int kBlockSeq = 8;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static const int kNumThreadsD = 16;
|
||||
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
|
||||
static const int kElementsPerLoad = 4;
|
||||
|
||||
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return get<4>(args.problem_size) % kElementsPerLoad == 0;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq));
|
||||
return grid;
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template<class StrideSrc, class StrideDest>
|
||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
|
||||
auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y;
|
||||
auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y;
|
||||
|
||||
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
|
||||
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
|
||||
if (idx_s >= count) continue;
|
||||
auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src);
|
||||
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest);
|
||||
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
ElementAccumulator value_src[kElementsPerLoad];
|
||||
Element value_dest[kElementsPerLoad];
|
||||
|
||||
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAccumulator> * kElementsPerLoad>;
|
||||
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
|
||||
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
|
||||
|
||||
for (int v = 0; v < kElementsPerLoad; v++) {
|
||||
value_dest[v] = value_src[v];
|
||||
}
|
||||
|
||||
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
if (params.ptr_src_dQ != nullptr) {
|
||||
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size));
|
||||
}
|
||||
if (params.ptr_src_dK != nullptr) {
|
||||
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size));
|
||||
}
|
||||
if (params.ptr_src_dV != nullptr) {
|
||||
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
134
examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp
Normal file
134
examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp
Normal file
@ -0,0 +1,134 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<class Element, class ElementAccumulator>
|
||||
struct FmhaKernelBwdSumOdO {
|
||||
|
||||
struct Arguments {
|
||||
cute::tuple<int, int, int, int, int> problem_size;
|
||||
|
||||
const Element* ptr_O;
|
||||
cute::tuple<int, int, int, cute::_1> stride_O;
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, int, int, cute::_1> stride_dO;
|
||||
|
||||
ElementAccumulator* ptr_sum_OdO;
|
||||
cute::tuple<int, int, _1> stride_sum_OdO;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
static constexpr int SharedStorageSize = 0;
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static const int kBlockQ = 16;
|
||||
|
||||
static const int kNumThreadsD = 8;
|
||||
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
|
||||
static const int kElementsPerLoad = 2;
|
||||
|
||||
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return get<4>(args.problem_size) % kElementsPerLoad == 0;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size));
|
||||
return grid;
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O);
|
||||
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO);
|
||||
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
|
||||
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
|
||||
if (idx_q >= get<2>(params.problem_size)) continue;
|
||||
ElementAccumulator acc = 0;
|
||||
auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O);
|
||||
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO);
|
||||
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO);
|
||||
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
Element value_O[kElementsPerLoad];
|
||||
Element value_dO[kElementsPerLoad];
|
||||
|
||||
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
|
||||
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
|
||||
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
|
||||
|
||||
for (int v = 0; v < kElementsPerLoad; v++) {
|
||||
acc += value_O[v] * value_dO[v];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 1; i < kNumThreadsD; i *= 2) {
|
||||
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
*ptr_sum_OdO_bhq = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
222
examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp
Normal file
222
examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp
Normal file
@ -0,0 +1,222 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
#include "../kernel/fmha_tile_scheduler.hpp"
|
||||
#include "../kernel/fmha_options.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
template<
|
||||
class CollectiveMainloop,
|
||||
class CollectiveEpilogue,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaKernelTma {
|
||||
|
||||
// Options
|
||||
static constexpr int kBlocksPerSM = find_option_t<Tag::kBlocksPerSM, Int<2>, Options...>::value;
|
||||
|
||||
using Element = typename CollectiveMainloop::Element;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
|
||||
using TileScheduler = IndividualTileScheduler;
|
||||
|
||||
using StagesQ = typename CollectiveMainloop::StagesQ;
|
||||
using Stages = typename CollectiveMainloop::Stages;
|
||||
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
|
||||
|
||||
using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ;
|
||||
using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK;
|
||||
|
||||
struct SharedStorage {
|
||||
union {
|
||||
typename CollectiveMainloop::SharedStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
};
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage;
|
||||
alignas(16) PipelineStorage pipeline_storage;
|
||||
alignas(16) PipelineStorageQ pipeline_storage_q;
|
||||
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
using ProblemShape = cute::tuple<int, int, int, int, int>;
|
||||
|
||||
struct Arguments {
|
||||
ProblemShape problem_size;
|
||||
typename CollectiveMainloop::Arguments mainloop;
|
||||
typename CollectiveEpilogue::Arguments epilogue;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ProblemShape problem_size;
|
||||
typename CollectiveMainloop::Params mainloop;
|
||||
typename CollectiveEpilogue::Params epilogue;
|
||||
typename TileScheduler::Params tile_scheduler;
|
||||
};
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = kBlocksPerSM;
|
||||
static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return Params{
|
||||
args.problem_size,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
|
||||
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
// Shared memory.
|
||||
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
|
||||
PipelineParamsQ pipeline_params_q;
|
||||
pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); // Q
|
||||
pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer;
|
||||
pipeline_params_q.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); // KV
|
||||
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{});
|
||||
MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{});
|
||||
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer;
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// State variables used for iterating the circular buffer
|
||||
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
|
||||
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
|
||||
PipelineState smem_pipe_read;
|
||||
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
|
||||
PipelineStateQ smem_pipe_read_q;
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer blocks in the Cluster
|
||||
// and to finish smem init
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
cute::cluster_wait();
|
||||
}
|
||||
else {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
auto result = collective_mainloop.compute(
|
||||
block_rank_in_cluster,
|
||||
blk_coord, params.mainloop, params.problem_size,
|
||||
pipeline, smem_pipe_read, smem_pipe_write,
|
||||
pipeline_q, smem_pipe_read_q, smem_pipe_write_q,
|
||||
storage.mainloop
|
||||
);
|
||||
|
||||
CollectiveEpilogue epilogue;
|
||||
epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord,
|
||||
result, typename CollectiveMainloop::TiledMmaPV{},
|
||||
params.problem_size, params.epilogue,
|
||||
epi_load_pipeline, storage.epilogue);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
@ -0,0 +1,418 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
#include "../kernel/fmha_options.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<
|
||||
class CollectiveMainloop,
|
||||
class CollectiveEpilogue,
|
||||
class TileScheduler,
|
||||
class... Options
|
||||
>
|
||||
struct FmhaKernelTmaWarpSpecialized {
|
||||
|
||||
// Options
|
||||
static constexpr bool kIsEpilogueLocked = find_option_t<Tag::kIsEpilogueLocked, false_type, Options...>::value;
|
||||
static constexpr bool kLoadsQSeparately = find_option_t<Tag::kLoadsQSeparately, false_type, Options...>::value;
|
||||
|
||||
|
||||
static const int NumLoadWarpGroups = 1;
|
||||
static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups;
|
||||
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||
|
||||
using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ;
|
||||
using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline;
|
||||
using MainloopPipelineReducer = cutlass::PipelineAsync<2>;
|
||||
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<
|
||||
StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||
|
||||
struct TensorStorageStruct {
|
||||
typename CollectiveMainloop::SharedStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
|
||||
};
|
||||
union TensorStorageUnion {
|
||||
typename CollectiveMainloop::SharedStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
|
||||
};
|
||||
using TensorStorage = std::conditional_t<CollectiveMainloop::kIsPersistent, TensorStorageStruct, TensorStorageUnion>;
|
||||
|
||||
struct SharedStorage {
|
||||
TensorStorage tensors;
|
||||
|
||||
using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage;
|
||||
using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage;
|
||||
using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage;
|
||||
|
||||
alignas(16) PipelineStorageInner pipeline_storage_inner;
|
||||
alignas(16) PipelineStorageOuter pipeline_storage_outer;
|
||||
alignas(16) PipelineStorageReducer pipeline_storage_reducer;
|
||||
|
||||
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
|
||||
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
|
||||
|
||||
alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier;
|
||||
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
using ProblemShape = cute::tuple<int, int, int, int, int>;
|
||||
|
||||
struct Arguments {
|
||||
ProblemShape problem_size;
|
||||
typename CollectiveMainloop::Arguments mainloop;
|
||||
typename CollectiveEpilogue::Arguments epilogue;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ProblemShape problem_size;
|
||||
typename CollectiveMainloop::Params mainloop;
|
||||
typename CollectiveEpilogue::Params epilogue;
|
||||
typename TileScheduler::Params tile_scheduler;
|
||||
};
|
||||
|
||||
using PipelineParamsInner = typename MainloopPipelineInner::Params;
|
||||
using PipelineStateInner = typename cutlass::PipelineState<MainloopPipelineInner::Stages>;
|
||||
using PipelineParamsOuter = typename MainloopPipelineOuter::Params;
|
||||
using PipelineStateOuter = typename cutlass::PipelineState<MainloopPipelineOuter::Stages>;
|
||||
using PipelineParamsReducer = typename MainloopPipelineReducer::Params;
|
||||
using PipelineStateReducer = typename cutlass::PipelineState<MainloopPipelineReducer::Stages>;
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8;
|
||||
static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return Params{
|
||||
args.problem_size,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
|
||||
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
|
||||
enum class WarpGroupRole {
|
||||
Producer = 0,
|
||||
Consumer0 = 1,
|
||||
Consumer1 = 2,
|
||||
Consumer2 = 3,
|
||||
Consumer3 = 4,
|
||||
};
|
||||
enum class ProducerWarpRole {
|
||||
LoadKV = 1,
|
||||
Reducer = 0,
|
||||
MaybeLoadQ = 2, // is kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it)
|
||||
MainloopEpilogue = 3,
|
||||
};
|
||||
|
||||
static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV;
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
// Shared memory.
|
||||
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
int lane_idx = cutlass::canonical_lane_idx();
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
auto warp_group_role = WarpGroupRole(warp_group_idx);
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0;
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
PipelineParamsOuter pipeline_params_outer;
|
||||
pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes;
|
||||
pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ);
|
||||
pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParamsInner pipeline_params_inner;
|
||||
pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes;
|
||||
pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV);
|
||||
pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParamsReducer pipeline_params_reducer;
|
||||
pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
||||
pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp;
|
||||
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) {
|
||||
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) {
|
||||
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) {
|
||||
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1 ||
|
||||
warp_group_role == WarpGroupRole::Consumer2 ||
|
||||
warp_group_role == WarpGroupRole::Consumer3
|
||||
) {
|
||||
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer;
|
||||
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer;
|
||||
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer;
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
|
||||
MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{});
|
||||
MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{});
|
||||
MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer);
|
||||
|
||||
// State variables used for iterating the circular buffer
|
||||
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
|
||||
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
|
||||
PipelineStateInner smem_pipe_read_inner;
|
||||
PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state<MainloopPipelineInner>();
|
||||
|
||||
PipelineStateOuter smem_pipe_read_outer;
|
||||
PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state<MainloopPipelineOuter>();
|
||||
|
||||
PipelineStateReducer smem_pipe_read_reducer;
|
||||
PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state<MainloopPipelineReducer>();
|
||||
|
||||
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_math_wg_order_barrier.group_id = consumer_warp_group_idx;
|
||||
params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||
MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier);
|
||||
|
||||
// Epilogue Load pipeline
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
|
||||
|
||||
if constexpr (kLoadsQSeparately) {
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp);
|
||||
}
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer blocks in the Cluster
|
||||
// and to finish smem init
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
cute::cluster_wait();
|
||||
}
|
||||
else {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
if (producer_warp_role == ProducerWarpRole::LoadKV) {
|
||||
bool do_barrier = kLoadsQSeparately;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
collective_mainloop.template load_kv_maybe_q<!kLoadsQSeparately>(
|
||||
block_rank_in_cluster,
|
||||
blk_coord, params.mainloop, params.problem_size,
|
||||
pipeline_inner, smem_pipe_write_inner,
|
||||
pipeline_outer, smem_pipe_write_outer,
|
||||
storage.tensors.mainloop,
|
||||
storage.load_warp_barrier, do_barrier
|
||||
);
|
||||
do_barrier = false;
|
||||
}
|
||||
}
|
||||
else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) {
|
||||
bool do_barrier = true;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
collective_mainloop.load_maybe_q(
|
||||
blk_coord, params.mainloop, params.problem_size,
|
||||
pipeline_outer, smem_pipe_write_outer,
|
||||
storage.tensors.mainloop,
|
||||
storage.load_warp_barrier, do_barrier
|
||||
);
|
||||
do_barrier = false;
|
||||
}
|
||||
} else if (producer_warp_role == ProducerWarpRole::Reducer) {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
collective_mainloop.reduce(
|
||||
blk_coord, params.mainloop, params.problem_size,
|
||||
pipeline_reducer, smem_pipe_read_reducer,
|
||||
storage.tensors.mainloop
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (
|
||||
warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1 ||
|
||||
warp_group_role == WarpGroupRole::Consumer2 ||
|
||||
warp_group_role == WarpGroupRole::Consumer3
|
||||
) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto wg_coord = blk_coord;
|
||||
|
||||
constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads;
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer0) {
|
||||
smem_pipe_read_outer.advance(0 * kOuterLoads);
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||
smem_pipe_read_outer.advance(1 * kOuterLoads);
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer2) {
|
||||
smem_pipe_read_outer.advance(2 * kOuterLoads);
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer3) {
|
||||
smem_pipe_read_outer.advance(3 * kOuterLoads);
|
||||
}
|
||||
|
||||
constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 0 : 1;
|
||||
auto& wg_block = get<wg_dim>(wg_coord);
|
||||
if (warp_group_role == WarpGroupRole::Consumer0) {
|
||||
wg_block = NumMmaWarpGroups * wg_block + 0;
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||
wg_block = NumMmaWarpGroups * wg_block + 1;
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer2) {
|
||||
wg_block = NumMmaWarpGroups * wg_block + 2;
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer3) {
|
||||
wg_block = NumMmaWarpGroups * wg_block + 3;
|
||||
}
|
||||
|
||||
auto result = collective_mainloop.compute(
|
||||
blk_coord, wg_coord,
|
||||
params.mainloop, params.problem_size,
|
||||
pipeline_inner, smem_pipe_read_inner,
|
||||
pipeline_outer, smem_pipe_read_outer,
|
||||
pipeline_reducer, smem_pipe_write_reducer,
|
||||
storage.tensors.mainloop,
|
||||
math_wg_order_barrier
|
||||
);
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer0) {
|
||||
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0));
|
||||
}
|
||||
if constexpr (NumMmaWarpGroups >= 2) {
|
||||
if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1));
|
||||
}
|
||||
}
|
||||
if constexpr (NumMmaWarpGroups >= 3) {
|
||||
if (warp_group_role == WarpGroupRole::Consumer2) {
|
||||
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2));
|
||||
}
|
||||
}
|
||||
if constexpr (NumMmaWarpGroups >= 4) {
|
||||
if (warp_group_role == WarpGroupRole::Consumer3) {
|
||||
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.wait();
|
||||
|
||||
CollectiveEpilogue epilogue;
|
||||
epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord,
|
||||
result, typename CollectiveMainloop::TiledMmaPV{},
|
||||
params.problem_size, params.epilogue,
|
||||
epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]);
|
||||
|
||||
if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.arrive();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -28,51 +28,56 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp> // CUTE_HOST_DEVICE
|
||||
#include <cute/numeric/integral_constant.hpp> // cute::true_type
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cute
|
||||
{
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
template <class T>
|
||||
struct ConstantTensor
|
||||
{
|
||||
template <class... Coords>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const&
|
||||
operator()(Coords const&...) const {
|
||||
return val_;
|
||||
}
|
||||
template<auto kTag, typename Default, typename... Options>
|
||||
struct find_option;
|
||||
|
||||
T val_;
|
||||
template<auto kTag, typename Default>
|
||||
struct find_option<kTag, Default> {
|
||||
using option_value = Default;
|
||||
};
|
||||
|
||||
struct TrivialPredTensor
|
||||
{
|
||||
template <class... Coords>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
true_type
|
||||
operator()(Coords const&...) const {
|
||||
return {};
|
||||
}
|
||||
template<auto kTag, typename Default, typename Option, typename... Options>
|
||||
struct find_option<kTag, Default, Option, Options...> :
|
||||
std::conditional_t<
|
||||
Option::tag == kTag,
|
||||
Option,
|
||||
find_option<kTag, Default, Options...>
|
||||
>
|
||||
{};
|
||||
|
||||
template<auto kTag, typename Default, typename... Options>
|
||||
using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
|
||||
|
||||
enum class Tag {
|
||||
kIsPersistent,
|
||||
kNumMmaWarpGroups,
|
||||
kLoadsQSeparately,
|
||||
|
||||
kIsMainloopLocked,
|
||||
kIsEpilogueLocked,
|
||||
|
||||
kStagesQ,
|
||||
kStagesKV,
|
||||
|
||||
kEpilogueKind,
|
||||
|
||||
kBlocksPerSM,
|
||||
kClusterM,
|
||||
|
||||
kAccQK
|
||||
};
|
||||
|
||||
template <class Fn>
|
||||
struct FunctionPredTensor
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
FunctionPredTensor(Fn const& fn) : fn_(fn) {}
|
||||
|
||||
template <class... Coords>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator()(Coords const&... coords) const {
|
||||
return fn_(coords...);
|
||||
}
|
||||
|
||||
Fn const& fn_;
|
||||
template<auto kTag, class Value>
|
||||
struct Option {
|
||||
static constexpr auto tag = kTag;
|
||||
using option_value = Value;
|
||||
};
|
||||
|
||||
} // end namespace cute
|
||||
} // namespace cutlass::fmha::kernel
|
||||
204
examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp
Normal file
204
examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp
Normal file
@ -0,0 +1,204 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct IndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
IndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape)
|
||||
{
|
||||
using namespace cute;
|
||||
dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size));
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
IndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct PersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_h;
|
||||
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape)
|
||||
{
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
|
||||
int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size);
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) },
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, bidh;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_h(block_decode, bidh, block_decode);
|
||||
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
PersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Base>
|
||||
struct TileSchedulerBwdAdapter {
|
||||
|
||||
using Params = typename Base::Params;
|
||||
|
||||
Base base_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TileSchedulerBwdAdapter(Params const& params) : base_(params) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape)
|
||||
{
|
||||
using namespace cute;
|
||||
return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape));
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return Base::get_grid_shape(params);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return base_.is_valid();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return select<1,0,2>(base_.get_block_coord());
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TileSchedulerBwdAdapter& operator++() {
|
||||
++base_;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
357
examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp
Normal file
357
examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp
Normal file
@ -0,0 +1,357 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ, class TensorK, class TensorV,
|
||||
class TensorO, class TensorLSE, class TensorDO,
|
||||
class TensorDQ, /* class TensorDK, class TensorDV, */
|
||||
class Fusion
|
||||
>
|
||||
void __global__ fmha_bwd_reference_dQ_kernel(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
using Element = typename TensorO::value_type;
|
||||
using ElementAccumulator = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ char mS_mem[];
|
||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||
|
||||
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
|
||||
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
|
||||
ElementAccumulator acc_qk = 0;
|
||||
ElementAccumulator acc_dov = 0;
|
||||
ElementAccumulator acc_doo = 0;
|
||||
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||
}
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
acc_qk = frag(0);
|
||||
|
||||
mS[idx_K] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
|
||||
}
|
||||
mDQ(idx_Q, idx_D, idx_L) = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ, class TensorK, class TensorV,
|
||||
class TensorO, class TensorLSE, class TensorDO,
|
||||
/* class TensorDQ, */ class TensorDK, /* class TensorDV, */
|
||||
class Fusion
|
||||
>
|
||||
void __global__ fmha_bwd_reference_dK_kernel(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
using Element = typename TensorO::value_type;
|
||||
using ElementAccumulator = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ char mS_mem[];
|
||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||
|
||||
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
|
||||
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
|
||||
|
||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
|
||||
ElementAccumulator acc_qk = 0;
|
||||
ElementAccumulator acc_dov = 0;
|
||||
ElementAccumulator acc_doo = 0;
|
||||
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||
}
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
acc_qk = frag(0);
|
||||
|
||||
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
|
||||
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
|
||||
}
|
||||
mDK(idx_K, idx_D, idx_L) = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ, class TensorK, class TensorV,
|
||||
class TensorO, class TensorLSE, class TensorDO,
|
||||
/* class TensorDQ, class TensorDK, */ class TensorDV,
|
||||
class Fusion
|
||||
>
|
||||
void __global__ fmha_bwd_reference_dV_kernel(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
using Element = typename TensorO::value_type;
|
||||
using ElementAccumulator = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ char mS_mem[];
|
||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||
|
||||
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||
|
||||
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
|
||||
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
|
||||
|
||||
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
|
||||
ElementAccumulator acc_qk = 0;
|
||||
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||
}
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc_qk;
|
||||
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
acc_qk = frag(0);
|
||||
|
||||
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
|
||||
acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L);
|
||||
}
|
||||
mDV(idx_K, idx_D, idx_L) = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ, class TensorK, class TensorV,
|
||||
class TensorO, class TensorLSE, class TensorDO,
|
||||
/**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/
|
||||
class Fusion
|
||||
>
|
||||
void fmha_bwd_reference_dQ(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||
/**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
|
||||
|
||||
if (shared_mem >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||
auto result = cudaFuncSetAttribute(
|
||||
fmha_bwd_reference_dQ_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDQ, Fusion>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
shared_mem);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ, class TensorK, class TensorV,
|
||||
class TensorO, class TensorLSE, class TensorDO,
|
||||
/** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/
|
||||
class Fusion
|
||||
>
|
||||
void fmha_bwd_reference_dK(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||
/** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
|
||||
|
||||
if (shared_mem >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||
auto result = cudaFuncSetAttribute(
|
||||
fmha_bwd_reference_dK_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDK, Fusion>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
shared_mem);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ, class TensorK, class TensorV,
|
||||
class TensorO, class TensorLSE, class TensorDO,
|
||||
/** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/
|
||||
class Fusion
|
||||
>
|
||||
void fmha_bwd_reference_dV(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||
/** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
|
||||
|
||||
if (shared_mem >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||
auto result = cudaFuncSetAttribute(
|
||||
fmha_bwd_reference_dV_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDV, Fusion>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
shared_mem);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ, class TensorK, class TensorV,
|
||||
class TensorO, class TensorLSE, class TensorDO,
|
||||
class TensorDQ, class TensorDK, class TensorDV,
|
||||
class Fusion
|
||||
>
|
||||
void fmha_bwd_reference(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||
TensorDQ mDQ, TensorDK mDK, TensorDV mDV,
|
||||
Fusion fusion
|
||||
) {
|
||||
fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
|
||||
fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
|
||||
fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
156
examples/88_hopper_fmha/reference/fmha_reference.hpp
Normal file
156
examples/88_hopper_fmha/reference/fmha_reference.hpp
Normal file
@ -0,0 +1,156 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ,
|
||||
class TensorK,
|
||||
class TensorV,
|
||||
class TensorO,
|
||||
class TensorLSE,
|
||||
class Fusion
|
||||
>
|
||||
void __global__ fmha_reference_kernel(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE,
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
using Element = typename TensorO::value_type;
|
||||
using ElementAccumulator = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ char mS_mem[];
|
||||
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||
|
||||
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) {
|
||||
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) {
|
||||
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) {
|
||||
acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L);
|
||||
}
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc;
|
||||
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
mS[idx_K] = static_cast<Element>(frag(0) * softmax_scale);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
ElementAccumulator maxS = -std::numeric_limits<ElementAccumulator>::infinity();
|
||||
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||
maxS = std::max<ElementAccumulator>(maxS, mS[idx_K]);
|
||||
}
|
||||
if (maxS == -std::numeric_limits<ElementAccumulator>::infinity()) maxS = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
|
||||
mS[idx_K] = static_cast<Element>(exp(mS[idx_K] - maxS));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
ElementAccumulator sum = 0;
|
||||
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||
sum += mS[idx_K];
|
||||
}
|
||||
|
||||
Element scale = static_cast<Element>(1.0 / sum);
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||
acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale;
|
||||
}
|
||||
mO(idx_Q, idx_D, idx_L) = static_cast<Element>(acc);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mLSE(idx_Q, idx_L) = log(sum) + maxS;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShape,
|
||||
class TensorQ,
|
||||
class TensorK,
|
||||
class TensorV,
|
||||
class TensorO,
|
||||
class TensorLSE,
|
||||
class Fusion
|
||||
>
|
||||
void fmha_reference(
|
||||
ProblemShape problem_shape,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE,
|
||||
Fusion fusion
|
||||
) {
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<0>(mO), size<2>(mO), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
|
||||
|
||||
if (shared_mem >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||
auto result = cudaFuncSetAttribute(
|
||||
fmha_reference_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, Fusion>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
shared_mem);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
fmha_reference_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
129
examples/88_hopper_fmha/reference/reference_abs_error.hpp
Normal file
129
examples/88_hopper_fmha/reference/reference_abs_error.hpp
Normal file
@ -0,0 +1,129 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
template<typename Element>
|
||||
__global__ void reference_abs_diff_kernel(
|
||||
Element* data, Element* data_ref, size_t count,
|
||||
double* max_diff, double* sum_diff,
|
||||
bool print_diff
|
||||
) {
|
||||
double thread_max_diff = 0;
|
||||
double thread_sum_diff = 0;
|
||||
|
||||
__shared__ double block_max_diff;
|
||||
__shared__ double block_sum_diff;
|
||||
|
||||
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
double diff = fabs(data[i] - data_ref[i]);
|
||||
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
thread_max_diff = fmax(diff, thread_max_diff);
|
||||
thread_sum_diff += diff;
|
||||
}
|
||||
|
||||
for (int i = 0; i < blockDim.x; i++) {
|
||||
if (i == threadIdx.x) {
|
||||
if (i == 0) {
|
||||
block_max_diff = thread_max_diff;
|
||||
block_sum_diff = thread_sum_diff;
|
||||
} else {
|
||||
block_max_diff = fmax(block_max_diff, thread_max_diff);
|
||||
block_sum_diff += thread_sum_diff;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
atomicAdd(sum_diff, block_sum_diff);
|
||||
|
||||
for (;;) {
|
||||
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
|
||||
double prev_diff = reinterpret_cast<double const&>(prev);
|
||||
double new_max_diff = fmax(block_max_diff, prev_diff);
|
||||
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
|
||||
if (found == prev) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Element>
|
||||
void reference_abs_diff(
|
||||
cutlass::DeviceAllocation<Element> const& data,
|
||||
cutlass::DeviceAllocation<Element> const& data_ref,
|
||||
double& max_diff, double& mean_diff
|
||||
) {
|
||||
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
|
||||
|
||||
cutlass::DeviceAllocation<double> result;
|
||||
result.reset(2);
|
||||
assert(data.size() == data_ref.size());
|
||||
|
||||
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Memset failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
dim3 block(256, 1, 1);
|
||||
dim3 grid(1024, 1, 1);
|
||||
reference_abs_diff_kernel<<<block, grid>>>(
|
||||
data.get(), data_ref.get(), data.size(),
|
||||
result.get(), result.get() + 1, kPrintDiff);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Difference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
double result_host[2];
|
||||
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Copy failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
max_diff = result_host[0];
|
||||
mean_diff = result_host[1] / static_cast<double>(data.size());
|
||||
}
|
||||
@ -163,6 +163,7 @@ foreach(EXAMPLE
|
||||
82_blackwell_distributed_gemm
|
||||
83_blackwell_sparse_gemm
|
||||
84_blackwell_narrow_precision_sparse_gemm
|
||||
88_hopper_fmha
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
||||
@ -55,3 +55,7 @@ cutlass_example_add_executable(
|
||||
tiled_copy.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
cute_tutorial_tiled_copy_if
|
||||
tiled_copy_if.cu
|
||||
)
|
||||
|
||||
@ -506,13 +506,13 @@ int main(int argc, char** argv)
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major < 8) {
|
||||
std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl;
|
||||
if (props.major != 9) {
|
||||
std::cout << "This example requires NVIDIA's Hopper Architecture GPU with compute capability 90a" << std::endl;
|
||||
// Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
int m = 5120;
|
||||
if (argc >= 2)
|
||||
@ -604,7 +604,7 @@ int main(int argc, char** argv)
|
||||
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
|
||||
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -461,7 +461,7 @@ int main(int argc, char** argv)
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
int m = 512;
|
||||
if (argc >= 2)
|
||||
@ -553,7 +553,7 @@ int main(int argc, char** argv)
|
||||
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
|
||||
|
||||
#else
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
297
examples/cute/tutorial/tiled_copy_if.cu
Normal file
297
examples/cute/tutorial/tiled_copy_if.cu
Normal file
@ -0,0 +1,297 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include "cutlass/util/print_error.hpp"
|
||||
#include "cutlass/util/GPU_Clock.hpp"
|
||||
#include "cutlass/util/helper_cuda.hpp"
|
||||
|
||||
// This example extends `tiled_copy` using predicate tensors to guard memory accesses performed
|
||||
// by `cute::copy_if()`. This enables tensors to have shapes that are not integer multiples of
|
||||
// block sizes.
|
||||
//
|
||||
// This is accomplished by instantiating a tensor of coordinates which correspond to tensor elements
|
||||
// to be accessed and then computing a predicate tensor which masks accesses. The example demonstrates
|
||||
// how constructing of an identity tensor containing coordinates and a predicate tensor containing
|
||||
// mask bits can be implemented using the same CuTe operations used to tile the tensors in
|
||||
// Global Memory.
|
||||
//
|
||||
// This example implements two variants:
|
||||
// - copy_if_kernel() uses `cute::local_partition()` to construct each thread's slice
|
||||
// - copy_if_kernel_vectorized() uses `make_tiled_copy() to implement vectorized memory accesses.
|
||||
//
|
||||
// The tensor shapes and strides must be divisible by the shape of the vector access.
|
||||
//
|
||||
|
||||
/// Simple copy kernel.
|
||||
//
|
||||
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
|
||||
template <class TensorS, class TensorD, class BlockShape, class ThreadLayout>
|
||||
__global__ void copy_if_kernel(TensorS S, TensorD D, BlockShape block_shape, ThreadLayout)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
|
||||
auto shape_S = shape(S);
|
||||
Tensor C = make_identity_tensor(shape_S);
|
||||
// Construct a predicate tensor which compares the coordinates with the original shape
|
||||
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
|
||||
|
||||
// Tile the input tensor into blocks
|
||||
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
|
||||
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||
|
||||
// Construct a partitioning of the tile among threads with the given thread arrangement.
|
||||
|
||||
// Concept: Tensor ThrLayout ThrIndex
|
||||
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
|
||||
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
|
||||
Tensor thr_tile_P = local_partition(tile_P, ThreadLayout{}, threadIdx.x);
|
||||
|
||||
// Copy from GMEM to GMEM using `thr_tile_P` to guard accesses.
|
||||
copy_if(thr_tile_P, thr_tile_S, thr_tile_D);
|
||||
}
|
||||
|
||||
/// Vectorized copy kernel.
|
||||
///
|
||||
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
|
||||
/// has the precondition that pointers are aligned to the vector size.
|
||||
///
|
||||
template <class TensorS, class TensorD, class BlockShape, class Tiled_Copy>
|
||||
__global__ void copy_if_kernel_vectorized(TensorS S, TensorD D, BlockShape block_shape, Tiled_Copy tiled_copy)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
|
||||
auto shape_S = shape(S);
|
||||
Tensor C = make_identity_tensor(shape_S);
|
||||
// Construct a predicate tensor which compares the coordinates with the original shape
|
||||
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
|
||||
|
||||
// Tile the input tensor into blocks
|
||||
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
|
||||
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||
|
||||
//
|
||||
// Construct a Tensor corresponding to each thread's slice.
|
||||
//
|
||||
ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
|
||||
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CPY, CPY_M, CPY_N)
|
||||
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CPY, CPY_M, CPY_N)
|
||||
Tensor thr_tile_P = thr_copy.partition_S(tile_P); // (CPY, CPY_M, CPY_N)
|
||||
|
||||
#if 0
|
||||
// Copy from GMEM to GMEM
|
||||
copy_if(tiled_copy, thr_tile_P, thr_tile_S, thr_tile_D);
|
||||
#else
|
||||
// make_fragment_like() constructs a tensor in RMEM with the same shape as thr_tile_S.
|
||||
Tensor frag = make_fragment_like(thr_tile_S);
|
||||
|
||||
// Copy from GMEM to RMEM and from RMEM to GMEM
|
||||
copy_if(tiled_copy, thr_tile_P, thr_tile_S, frag);
|
||||
copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D);
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Main function
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
//
|
||||
// Given a 2D shape, perform an efficient copy
|
||||
//
|
||||
|
||||
using namespace cute;
|
||||
using Element = float;
|
||||
|
||||
// Define a tensor shape with dynamic extents (m, n)
|
||||
auto tensor_shape = make_shape(528, 300);
|
||||
|
||||
thrust::host_vector<Element> h_S(size(tensor_shape));
|
||||
thrust::host_vector<Element> h_D(size(tensor_shape));
|
||||
|
||||
//
|
||||
// Initialize
|
||||
//
|
||||
|
||||
for (size_t i = 0; i < h_S.size(); ++i) {
|
||||
h_S[i] = static_cast<Element>(i);
|
||||
h_D[i] = Element{};
|
||||
}
|
||||
|
||||
thrust::device_vector<Element> d_S = h_S;
|
||||
thrust::device_vector<Element> d_D = h_D;
|
||||
thrust::device_vector<Element> d_Zero = h_D;
|
||||
|
||||
//
|
||||
// Make tensors
|
||||
//
|
||||
|
||||
Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
|
||||
Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));
|
||||
|
||||
//
|
||||
// Partition
|
||||
//
|
||||
|
||||
// Define a statically sized block (M, N).
|
||||
//
|
||||
// Note, by convention, capital letters are used to represent static modes.
|
||||
auto block_shape = make_shape(Int<128>{}, Int<64>{});
|
||||
|
||||
// Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
|
||||
// shape, and modes (m', n') correspond to the number of tiles.
|
||||
//
|
||||
// These will be used to determine the CUDA kernel grid dimensinos.
|
||||
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
|
||||
|
||||
// Describes the layout of threads which is then replicated to tile 'block_shape.'
|
||||
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (ThrM, ThrN)
|
||||
|
||||
//
|
||||
// Determine grid and block dimensions
|
||||
//
|
||||
|
||||
dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n'
|
||||
dim3 blockDim(size(thr_layout));
|
||||
|
||||
//
|
||||
// Launch the kernel
|
||||
//
|
||||
|
||||
// copy_if()
|
||||
copy_if_kernel<<< gridDim, blockDim >>>(
|
||||
tensor_S,
|
||||
tensor_D,
|
||||
block_shape,
|
||||
thr_layout);
|
||||
|
||||
cudaError result = cudaDeviceSynchronize();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
h_D = d_D;
|
||||
|
||||
//
|
||||
// Verification
|
||||
//
|
||||
|
||||
auto verify = [](thrust::host_vector<Element> const &S, thrust::host_vector<Element> const &D){
|
||||
|
||||
int32_t errors = 0;
|
||||
int32_t const kErrorLimit = 10;
|
||||
|
||||
if (S.size() != D.size()) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < D.size(); ++i) {
|
||||
if (S[i] != D[i]) {
|
||||
std::cerr << "Error. S[" << i << "]: " << S[i] << ", D[" << i << "]: " << D[i] << std::endl;
|
||||
|
||||
if (++errors >= kErrorLimit) {
|
||||
std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl;
|
||||
return errors;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors;
|
||||
};
|
||||
|
||||
if (verify(h_D, h_S)) {
|
||||
return -1;
|
||||
} else {
|
||||
std::cout << "Success." << std::endl;
|
||||
}
|
||||
|
||||
thrust::copy(d_Zero.begin(), d_Zero.end(), d_D.begin());
|
||||
|
||||
// Construct a TiledCopy with a specific access pattern.
|
||||
// This version uses a
|
||||
// (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc),
|
||||
// (2) Layout-of-Values that each thread will access.
|
||||
|
||||
// Value arrangement per thread
|
||||
Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx
|
||||
|
||||
// Define `AccessType` which controls the size of the actual memory access instruction.
|
||||
using CopyOp = UniversalCopy<uint_byte_t<sizeof(Element) * size(val_layout)>>; // A very specific access width copy instruction
|
||||
//using CopyOp = UniversalCopy<cutlass::AlignedArray<Element, size(val_layout)>>; // A more generic type that supports many copy strategies
|
||||
//using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs
|
||||
|
||||
// A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element.
|
||||
using Atom = Copy_Atom<CopyOp, Element>;
|
||||
|
||||
// Construct tiled copy, a tiling of copy atoms.
|
||||
//
|
||||
// Note, this assumes the vector and thread layouts are aligned with contigous data
|
||||
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
|
||||
// reads. Alternative value layouts are also possible, though incompatible layouts
|
||||
// will result in compile time errors.
|
||||
TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy
|
||||
thr_layout, // thread layout (e.g. 32x4 Col-Major)
|
||||
val_layout); // value layout (e.g. 4x1)
|
||||
|
||||
// copy_if() with vectorization
|
||||
copy_if_kernel_vectorized<<< gridDim, blockDim >>>(
|
||||
tensor_S,
|
||||
tensor_D,
|
||||
block_shape,
|
||||
tiled_copy);
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
h_D = d_D;
|
||||
|
||||
if (verify(h_D, h_S)) {
|
||||
return -1;
|
||||
} else {
|
||||
std::cout << "Success." << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
392
examples/python/CuTeDSL/ampere/elementwise_add.py
Normal file
392
examples/python/CuTeDSL/ampere/elementwise_add.py
Normal file
@ -0,0 +1,392 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import time
|
||||
from typing import Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
"""
|
||||
An Elementwise Addition Example using CuTe DSL.
|
||||
|
||||
This example kernel copies data from global memory to register memory (rmem), performs the elementwise
|
||||
addition operation, and stores the result back to global memory.
|
||||
|
||||
Primary goals of this example are to demonstrate how basic global memory copies can be expressed in
|
||||
CuTe DSL and illustrate canonical partitioning patterns in CuTe. It also implements canonical
|
||||
predication for tensors whose shape is not multiple of tile size to guard OOB reads.
|
||||
|
||||
Thread-value (or TV) layouts are central to canonical partitioning patterns in CuTe. They provide a
|
||||
mapping from thread and a thread's value to the set of coordinates within a tile that we have sliced
|
||||
out from a data tensor.
|
||||
|
||||
The input tensors are row-major layout, that leading dimension is the right most dimension. In order
|
||||
to efficiently copy data from global memory, we must map threads contiguously on row dimension.
|
||||
|
||||
Thread ID mapping to 2D coordinates with layout `(4,32):(32,1)`:
|
||||
|
||||
+----+----+----+----+-----+----+
|
||||
| | 0 | 1 | 2 | ... | 31 |
|
||||
+----+----+----+----+-----+----+
|
||||
| 0 | T0 | T1 | T2 | ... | T31|
|
||||
+----+----+----+----+-----+----+
|
||||
| 1 |T32 |T33 |T34 | ... |T63 |
|
||||
+----+----+----+----+-----+----+
|
||||
| 2 |T64 |T65 |T66 | ... |T95 |
|
||||
+----+----+----+----+-----+----+
|
||||
| 3 |T96 |T97 |T98 | ... |T127|
|
||||
+----+----+----+----+-----+----+
|
||||
|
||||
As Ampere GPU supports a maximum of 128bit per load/store instruction and each element is 32bit, we
|
||||
can load 4 elements per instruction. Having additional contiguous values allows for vectorization
|
||||
across threads (coalesced accesses) and is required for saturating the memory bandwidth.
|
||||
|
||||
We use `(4,4):(4,1)` as the val layout in this example. Notice that the major mode is the same as
|
||||
the major mode of the input tensor - without which vectorization would not be possible.
|
||||
|
||||
If you already know the TV layout you want to use for your tiled copy, CuTe DSL provides utility
|
||||
`cute.make_layout_tv` to build the tiled copy type around it and the atom of your choice.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
thr_layout = cute.make_layout((4, 32), stride=(32, 1))
|
||||
val_layout = cute.make_layout((4, 4), stride=(4, 1))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
# Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN))
|
||||
gA = cute.zipped_divide(mA, tiler_mn)
|
||||
|
||||
where `tiler_mn` is the tile size per thread block and `tv_layout` is the TV layout which maps
|
||||
thread index and inter-thread index of data array per thread to logical coordinates of elements in
|
||||
input and output tensors.
|
||||
|
||||
Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy` utility.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
blkA = gA[((None, None), bidx)] # (TileM,TileN)
|
||||
|
||||
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
|
||||
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
|
||||
# get slice of tiled_copy_A for current thread
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
|
||||
# partition per thread block tensor as source of tiled copy
|
||||
thrA = thr_copy_A.partition_S(blkA)
|
||||
|
||||
# allocate fragment for gmem->rmem
|
||||
frgA = cute.make_fragment_like(thrA)
|
||||
|
||||
# copy data from global memory to register memory
|
||||
cute.copy(copy_atom_load, thrA, frgA)
|
||||
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/elementwise_add.py --M 3 --N 12
|
||||
python examples/ampere/elementwise_add.py --M 1024 --N 512
|
||||
python examples/ampere/elementwise_add.py --M 1024 --N 1024 --benchmark --warmup_iterations 2 --iterations 1000
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Don't iterate too many times when profiling with ncu
|
||||
ncu python examples/ampere/elementwise_add.py --M 2048 --N 2048 --benchmark --iterations 10 --skip_ref_check
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def elementwise_add_kernel(
|
||||
gA: cute.Tensor,
|
||||
gB: cute.Tensor,
|
||||
gC: cute.Tensor,
|
||||
cC: cute.Tensor, # coordinate tensor
|
||||
shape: cute.Shape,
|
||||
tv_layout: cute.Layout,
|
||||
tiler_mn: cute.Shape,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
# slice for CTAs
|
||||
# logical id -> address
|
||||
blk_coord = ((None, None), bidx)
|
||||
blkA = gA[blk_coord] # (TileM,TileN)
|
||||
blkB = gB[blk_coord] # (TileM,TileN)
|
||||
blkC = gC[blk_coord] # (TileM,TileN)
|
||||
blkCrd = cC[blk_coord] # (TileM, TileN)
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread block:")
|
||||
print(f"[DSL INFO] blkA = {blkA.type}")
|
||||
print(f"[DSL INFO] blkB = {blkB.type}")
|
||||
print(f"[DSL INFO] blkC = {blkC.type}")
|
||||
print(f"[DSL INFO] blkCrd = {blkCrd.type}")
|
||||
|
||||
# # declare the atoms which will be used later for memory copy
|
||||
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
|
||||
copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)
|
||||
|
||||
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
tiled_copy_B = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
tiled_copy_C = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
|
||||
thrA = thr_copy_A.partition_S(blkA)
|
||||
thrB = thr_copy_B.partition_S(blkB)
|
||||
thrC = thr_copy_C.partition_S(blkC)
|
||||
|
||||
# allocate fragments for gmem->rmem
|
||||
frgA = cute.make_fragment_like(thrA)
|
||||
frgB = cute.make_fragment_like(thrB)
|
||||
frgC = cute.make_fragment_like(thrC)
|
||||
|
||||
thrCrd = thr_copy_C.partition_S(blkCrd)
|
||||
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread:")
|
||||
print(f"[DSL INFO] thrA = {thrA.type}")
|
||||
print(f"[DSL INFO] thrB = {thrB.type}")
|
||||
print(f"[DSL INFO] thrC = {thrC.type}")
|
||||
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
|
||||
|
||||
for i in cutlass.range_dynamic(0, cute.size(frgPred), 1):
|
||||
val = cute.elem_less(thrCrd[i], shape)
|
||||
frgPred[i] = val
|
||||
|
||||
# Print per thread predicate mask
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.printf("block_dim = {}", cute.arch.grid_dim())
|
||||
# cute.printf("shape = {}", shape)
|
||||
# cute.print_tensor(thrA)
|
||||
# cute.print_tensor(thrB)
|
||||
# cute.print_tensor(frgPred)
|
||||
|
||||
##########################################################
|
||||
# Move data to reg address space
|
||||
##########################################################
|
||||
|
||||
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
|
||||
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
|
||||
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.print_tensor(frgA)
|
||||
# cute.print_tensor(frgB)
|
||||
|
||||
# Load data before use. The compiler will optimize the copy and load
|
||||
# operations to convert some memory ld/st into register uses.
|
||||
result = frgA.load() + frgB.load()
|
||||
|
||||
# Save the results back to registers. Here we reuse b's registers.
|
||||
frgC.store(result)
|
||||
|
||||
# Copy the results back to c
|
||||
cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128):
|
||||
dtype = mA.element_type
|
||||
vector_size = copy_bits // dtype.width
|
||||
|
||||
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
|
||||
val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
print(f"[DSL INFO] Input Tensors:")
|
||||
print(f"[DSL INFO] mA = {mA.type}")
|
||||
print(f"[DSL INFO] mB = {mB.type}")
|
||||
|
||||
print(f"[DSL INFO] Tiling Parameters:")
|
||||
print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
|
||||
print(f"[DSL INFO] tv_layout = {tv_layout}")
|
||||
|
||||
gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
print(f"[DSL INFO] Tiled Tensors:")
|
||||
print(f"[DSL INFO] gA = {gA.type}")
|
||||
print(f"[DSL INFO] gB = {gB.type}")
|
||||
print(f"[DSL INFO] gC = {gC.type}")
|
||||
|
||||
idC = cute.make_identity_tensor(mC.shape)
|
||||
cC = cute.zipped_divide(idC, tiler=tiler_mn)
|
||||
print(f"[DSL INFO] coord tensor = {cC.type}")
|
||||
|
||||
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, tv_layout, tiler_mn).launch(
|
||||
grid=[cute.size(gC, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
|
||||
|
||||
def run_elementwise_add(
|
||||
M,
|
||||
N,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
is_a_dynamic_layout=False,
|
||||
is_b_dynamic_layout=False,
|
||||
is_result_dynamic_layout=False,
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
warmup_iterations=2,
|
||||
iterations=200,
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(f"Ampere GPU is required to run this example!")
|
||||
|
||||
print(f"\nRunning Elementwise Add test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"Input and Output Data type: {dtype}")
|
||||
|
||||
torch_dtype = cutlass_torch.dtype(dtype)
|
||||
if dtype.is_integer:
|
||||
a = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype)
|
||||
b = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype)
|
||||
else:
|
||||
a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
|
||||
c = torch.zeros_like(a)
|
||||
|
||||
print(f"Input tensor shapes:")
|
||||
print(f"a: {a.shape}, dtype: {a.dtype}")
|
||||
print(f"b: {b.shape}, dtype: {b.dtype}")
|
||||
print(f"c: {c.shape}, dtype: {c.dtype}\n")
|
||||
|
||||
if not is_a_dynamic_layout:
|
||||
a_tensor = from_dlpack(a).mark_layout_dynamic()
|
||||
else:
|
||||
a_tensor = a
|
||||
|
||||
if not is_b_dynamic_layout:
|
||||
b_tensor = from_dlpack(b).mark_layout_dynamic()
|
||||
else:
|
||||
b_tensor = b
|
||||
|
||||
if not is_result_dynamic_layout:
|
||||
c_tensor = from_dlpack(c).mark_layout_dynamic()
|
||||
else:
|
||||
c_tensor = c
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
compiled_func = cute.compile(elementwise_add, a_tensor, b_tensor, c_tensor)
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing vector add kernel...")
|
||||
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
if not skip_ref_check:
|
||||
compiled_func(a_tensor, b_tensor, c_tensor)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(a + b, c)
|
||||
print("Results verified successfully!")
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
# Create CUDA events for timing
|
||||
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_func(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Use the current stream for CUDA events instead of the default stream
|
||||
# Record start event
|
||||
cuda.cuEventRecord(start_event, current_stream)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
compiled_func(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Record end event
|
||||
cuda.cuEventRecord(end_event, current_stream)
|
||||
cuda.cuEventSynchronize(end_event)
|
||||
|
||||
# Calculate elapsed time
|
||||
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
|
||||
avg_time = elapsed_time / iterations
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s"
|
||||
)
|
||||
print(f"First few elements of result: \n{c[:3, :3]}")
|
||||
|
||||
# Destroy events
|
||||
cuda.cuEventDestroy(start_event)
|
||||
cuda.cuEventDestroy(end_event)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=1024, type=int)
|
||||
parser.add_argument("--N", default=1024, type=int)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
run_elementwise_add(
|
||||
args.M,
|
||||
args.N,
|
||||
dtype=cutlass.Float32,
|
||||
is_a_dynamic_layout=True,
|
||||
is_b_dynamic_layout=True,
|
||||
is_result_dynamic_layout=True,
|
||||
skip_ref_check=args.skip_ref_check,
|
||||
benchmark=args.benchmark,
|
||||
warmup_iterations=args.warmup_iterations,
|
||||
iterations=args.iterations,
|
||||
)
|
||||
print("\nPASS")
|
||||
390
examples/python/CuTeDSL/ampere/elementwise_apply.py
Normal file
390
examples/python/CuTeDSL/ampere/elementwise_apply.py
Normal file
@ -0,0 +1,390 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import argparse
|
||||
import operator
|
||||
import torch
|
||||
from typing import Type
|
||||
import time
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
An Elementwise Apply Example using CuTe DSL.
|
||||
|
||||
This example kernel demonstrates the meta-programming capability of the CuTe DSL by allowing
|
||||
customization of elementwise operations through lambda functions. The kernel copies data from
|
||||
global memory to register memory (rmem), applies a user-defined operation to the elements,
|
||||
and stores the result back to global memory.
|
||||
|
||||
Primary goals of this example:
|
||||
1. Demonstrate meta-programming capability by passing lambda functions to customize elementwise operations
|
||||
2. Show how to apply different operations (add, multiply, etc.) using the same kernel structure
|
||||
3. Illustrate how to parameterize CUDA kernels with operation types at compile time
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Run with addition operation
|
||||
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op add
|
||||
|
||||
# Run with multiplication operation
|
||||
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op mul
|
||||
|
||||
# Run with subtraction operation
|
||||
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op sub
|
||||
|
||||
# Benchmark performance
|
||||
python examples/ampere/elementwise_apply.py --M 2048 --N 2048 --op add --benchmark --warmup_iterations 2 --iterations 10
|
||||
|
||||
The example demonstrates how to express complex CUDA kernels with customizable operations
|
||||
while maintaining high performance through efficient memory access patterns.
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def elementwise_apply_kernel(
|
||||
op: cutlass.Constexpr,
|
||||
gA: cute.Tensor,
|
||||
gB: cute.Tensor,
|
||||
gC: cute.Tensor,
|
||||
cC: cute.Tensor, # coordinate tensor
|
||||
shape: cute.Shape,
|
||||
tv_layout: cute.Layout, # (tid, vid) -> logic coord
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
# slice for CTAs
|
||||
cta_coord = ((None, None), bidx)
|
||||
# logical coord -> address
|
||||
ctaA = gA[cta_coord] # (TileM, TileN)
|
||||
ctaB = gB[cta_coord] # (TileM, TileN)
|
||||
ctaC = gC[cta_coord] # (TileM, TileN)
|
||||
ctaCrd = cC[cta_coord] # (TileM, TileN)
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread block:")
|
||||
print(f"[DSL INFO] ctaA = {ctaA.type}")
|
||||
print(f"[DSL INFO] ctaB = {ctaB.type}")
|
||||
print(f"[DSL INFO] ctaC = {ctaC.type}")
|
||||
print(f"[DSL INFO] ctaCrd = {ctaCrd.type}")
|
||||
|
||||
# compose with CTA TV layout
|
||||
# (tid, vid) -> address
|
||||
tidfrgA = cute.composition(ctaA, tv_layout)
|
||||
tidfrgB = cute.composition(ctaB, tv_layout)
|
||||
tidfrgC = cute.composition(ctaC, tv_layout)
|
||||
tidfrgCrd = cute.composition(ctaCrd, tv_layout)
|
||||
# print(f"{tv_layout = }")
|
||||
# print(f"{tidfrgA = }")
|
||||
|
||||
thr_coord = (tidx, (None, None))
|
||||
|
||||
# slice for threads
|
||||
# vid -> address
|
||||
thrA = tidfrgA[thr_coord] # (V)
|
||||
thrB = tidfrgB[thr_coord] # (V)
|
||||
thrC = tidfrgC[thr_coord] # (V)
|
||||
thrCrd = tidfrgCrd[thr_coord]
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread:")
|
||||
print(f"[DSL INFO] thrA = {thrA.type}")
|
||||
print(f"[DSL INFO] thrB = {thrB.type}")
|
||||
print(f"[DSL INFO] thrC = {thrC.type}")
|
||||
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
|
||||
|
||||
# allocate fragments for gmem->rmem
|
||||
frgA = cute.make_fragment_like(thrA, gA.element_type)
|
||||
frgB = cute.make_fragment_like(thrB, gB.element_type)
|
||||
frgC = cute.make_fragment_like(thrC, gC.element_type)
|
||||
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
|
||||
|
||||
for i in cutlass.range_dynamic(cute.size(frgPred), unroll=1):
|
||||
frgPred[i] = cute.elem_less(thrCrd[i], shape)
|
||||
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.print_tensor(frgPred)
|
||||
|
||||
##########################################################
|
||||
# Move data to reg address space
|
||||
##########################################################
|
||||
|
||||
# declare the atoms which will be used later for memory copy
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
gA.element_type,
|
||||
num_bits_per_copy=gA.element_type.width,
|
||||
)
|
||||
copy_atom_store = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
gC.element_type,
|
||||
num_bits_per_copy=gC.element_type.width,
|
||||
)
|
||||
|
||||
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
|
||||
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
|
||||
|
||||
# Load data before use. The compiler will optimize the copy and load
|
||||
# operations to convert some memory ld/st into register uses.
|
||||
result = op(frgA.load(), frgB.load())
|
||||
|
||||
# Save the results back to registers. Here we reuse b's registers.
|
||||
frgC.store(result)
|
||||
|
||||
# Copy the results back to c
|
||||
cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def elementwise_apply(
|
||||
op: cutlass.Constexpr,
|
||||
a: cute.Tensor,
|
||||
b: cute.Tensor,
|
||||
result: cute.Tensor,
|
||||
):
|
||||
"""CUDA kernel applying binary operator on each element of two n-D input tensors in
|
||||
CuTe Python and store to result tensor.
|
||||
|
||||
:param op: Binary operator or lambda function to apply element-wise
|
||||
:type op: cutlass.Constexpr
|
||||
:param a: First input tensor
|
||||
:type a: cute.Tensor
|
||||
:param b: Second input tensor
|
||||
:type b: cute.Tensor
|
||||
:param result: Output tensor to store the results of op(a, b)
|
||||
:type result: cute.Tensor
|
||||
:return: None
|
||||
:rtype: None
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example 1: Adding two tensors
|
||||
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, device="cuda")
|
||||
y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32, device="cuda")
|
||||
result = torch.empty_like(x)
|
||||
elementwise_apply(operator.add, from_dlpack(x), from_dlpack(y), from_dlpack(result))
|
||||
# result:
|
||||
# tensor([[6.0, 8.0],
|
||||
# [10.0, 12.0]], device='cuda:0')
|
||||
|
||||
# Example 2: Using a lambda function
|
||||
elementwise_apply(lambda a, b: a * a + b * b, from_dlpack(x), from_dlpack(y), from_dlpack(result))
|
||||
# result:
|
||||
# tensor([[ 2., 8.],
|
||||
# [ 54., 512.]], device='cuda:0')
|
||||
"""
|
||||
|
||||
# Baseline: naive TV layout
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (512, 4) tile
|
||||
# * tidx maps to mode-0 but input layout is contiguous on mode-1, performance will be bad
|
||||
# tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1)))
|
||||
# cta_tiler = (512, 4)
|
||||
|
||||
# Opt-1: better TV layout with better 1D thread layout (SOL with 1D thread layout)
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (4, 512) tile
|
||||
# * tidx maps to mode-1 which is leading mode of input tensor for coalesced load
|
||||
# tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1)))
|
||||
# cta_tiler = (4, 512)
|
||||
|
||||
# Opt-2: 2D tile but worse
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (128, 16) logical tile
|
||||
# * V layout is bad as contiguous mode is not on right-most
|
||||
# * `cute.copy` only supports vectorize when stride-1 of v-layout on right-most )
|
||||
# tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128)))
|
||||
# cta_tiler = (128, 16)
|
||||
|
||||
# Opt-3: SOL with 2D thread tile
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (16, 128) logical tile
|
||||
# * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store
|
||||
thr_layout = cute.make_layout((4, 32), stride=(32, 1))
|
||||
val_layout = cute.make_layout((4, 4), stride=(4, 1))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
print(f"[DSL INFO] Input Tensors:")
|
||||
print(f"[DSL INFO] a = {a.type}")
|
||||
print(f"[DSL INFO] b = {b.type}")
|
||||
print(f"[DSL INFO] result = {result.type}")
|
||||
|
||||
print(f"[DSL INFO] Tiling Parameters:")
|
||||
print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
|
||||
print(f"[DSL INFO] tv_layout = {tv_layout}")
|
||||
|
||||
gA = cute.zipped_divide(a, tiler_mn) # ((TileM, TileN), (RestM, RestN))
|
||||
gB = cute.zipped_divide(b, tiler_mn) # ((TileM, TileN), (RestM, RestN))
|
||||
gC = cute.zipped_divide(result, tiler_mn) # ((TileM, TileN), (RestM, RestN))
|
||||
|
||||
print(f"[DSL INFO] Tiled Tensors:")
|
||||
print(f"[DSL INFO] gA = {gA.type}")
|
||||
print(f"[DSL INFO] gB = {gB.type}")
|
||||
print(f"[DSL INFO] gC = {gC.type}")
|
||||
|
||||
idC = cute.make_identity_tensor(result.shape)
|
||||
cC = cute.zipped_divide(idC, tiler=tiler_mn)
|
||||
print(f"[DSL INFO] coord tensor = {cC.type}")
|
||||
|
||||
# Launch the kernel asynchronously
|
||||
# Async token(s) can also be specified as dependencies
|
||||
elementwise_apply_kernel(
|
||||
op,
|
||||
gA,
|
||||
gB,
|
||||
gC,
|
||||
cC,
|
||||
result.shape,
|
||||
tv_layout,
|
||||
).launch(
|
||||
grid=[cute.size(gC, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
|
||||
|
||||
def run_elementwise_apply_and_verify(
|
||||
op,
|
||||
M,
|
||||
N,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
warmup_iterations=2,
|
||||
iterations=100,
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(f"Ampere GPU is required to run this example!")
|
||||
|
||||
print(f"\nRunning Elementwise Apply test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"Input and Output Data type: {dtype}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Measurement iterations: {iterations}\n")
|
||||
|
||||
torch_dtype = cutlass_torch.dtype(dtype)
|
||||
|
||||
# Allocate tensors with random values.
|
||||
a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
c = torch.zeros_like(a)
|
||||
|
||||
print(f"Input tensor shapes:")
|
||||
print(f"a: {a.shape}, dtype: {a.dtype}")
|
||||
print(f"b: {b.shape}, dtype: {b.dtype}")
|
||||
print(f"c: {c.shape}, dtype: {c.dtype}\n")
|
||||
|
||||
epsilon = 1.2
|
||||
if op in (operator.truediv, operator.floordiv):
|
||||
b = torch.where(b == 0, torch.tensor(epsilon), b)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
compiled_func = cute.compile(elementwise_apply, op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing elementwise apply kernel...")
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
if not skip_ref_check:
|
||||
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(op(a, b), c)
|
||||
print("Results verified successfully!")
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
# Create CUDA events for timing
|
||||
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
|
||||
# Record start event
|
||||
cuda.cuEventRecord(start_event, current_stream)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
compiled_func(from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic())
|
||||
|
||||
# Record end event
|
||||
cuda.cuEventRecord(end_event, current_stream)
|
||||
cuda.cuEventSynchronize(end_event)
|
||||
|
||||
# Calculate elapsed time
|
||||
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
|
||||
avg_time = elapsed_time / iterations
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s"
|
||||
)
|
||||
print(f"First few elements of result: \n{c[:3, :3]}")
|
||||
|
||||
# Destroy events
|
||||
cuda.cuEventDestroy(start_event)
|
||||
cuda.cuEventDestroy(end_event)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise apply to demonstrate building elementwise kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=128, type=int)
|
||||
parser.add_argument("--N", default=128, type=int)
|
||||
parser.add_argument("--op", default="add", type=str)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
args = parser.parse_args()
|
||||
run_elementwise_apply_and_verify(
|
||||
getattr(operator, args.op),
|
||||
args.M,
|
||||
args.N,
|
||||
dtype=cutlass.Float32,
|
||||
warmup_iterations=args.warmup_iterations,
|
||||
iterations=args.iterations,
|
||||
skip_ref_check=args.skip_ref_check,
|
||||
benchmark=args.benchmark,
|
||||
)
|
||||
print("\nPASS")
|
||||
1353
examples/python/CuTeDSL/ampere/flash_attention_v2.py
Normal file
1353
examples/python/CuTeDSL/ampere/flash_attention_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
780
examples/python/CuTeDSL/ampere/sgemm.py
Normal file
780
examples/python/CuTeDSL/ampere/sgemm.py
Normal file
@ -0,0 +1,780 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
A dense FP32 SIMT GEMM (C = A * B) example using CUTE DSL.
|
||||
- Matrix A is MxK, A can be row-major("K") or column-major("M")
|
||||
- Matrix B is NxK, B can be row-major("N") or column-major("K")
|
||||
- Matrix C is MxN, C can be row-major("N") or column-major("M")
|
||||
|
||||
This GEMM kernel supports the following features:
|
||||
- Utilizes FPU for matrix multiply-accumulate (MMA) operations
|
||||
- Use multistage pipeline to overlap computation and memory access
|
||||
* Shared memory pipeline: hides gmem-to-smem latency.
|
||||
* Register pipeline: overlaps shared memory-to-register transfers with
|
||||
computations and eliminates false data dependencies for
|
||||
better parallelism.
|
||||
- Use vectorized copies
|
||||
- Add padding to reduce bank conflicts in global -> shared memory copies
|
||||
- Use predication to avoid unnecessary copies or copies of stale data
|
||||
|
||||
This GEMM works as follows:
|
||||
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
|
||||
2. Perform matrix multiply-accumulate (MMA) operations using simple fused multiply-add atomics.
|
||||
3. Store results from registers (RMEM) to global memory (GMEM).
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/sgemm.py \
|
||||
--mnk 8192,8192,8192 \
|
||||
--a_major m --b_major n --c_major n
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
ncu python examples/ampere/sgemm.py \
|
||||
--mnk 8192,8192,8192 \
|
||||
--a_major m --b_major n --c_major n \
|
||||
--skip_ref_check --iterations 2
|
||||
|
||||
Constraints:
|
||||
* Supported input, output, and accumulator data types: fp32
|
||||
* Default tile shape is set to be 128x128x8
|
||||
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned
|
||||
"""
|
||||
|
||||
|
||||
class SGemm:
|
||||
def __init__(
|
||||
self,
|
||||
cta_tiler: Tuple[int, int, int] = (128, 128, 8),
|
||||
num_stages: int = 3,
|
||||
num_threads: int = 256,
|
||||
):
|
||||
self._cta_tiler = cta_tiler
|
||||
self._num_stages = num_stages
|
||||
self._num_threads = num_threads
|
||||
assert num_threads > 0, "needs at least one thread"
|
||||
assert num_threads % 16 == 0, "multiples of 16 required for MMA thread layout"
|
||||
|
||||
self._bM, self._bN, self._bK = self._cta_tiler
|
||||
assert self._bM % 16 == 0, "multiple of 16 required for tile dimension M"
|
||||
assert self._bN % 16 == 0, "multiple of 16 required for tile dimension N"
|
||||
assert self._num_stages >= 3, "num_stages must be greater than or equal to 3"
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
|
||||
self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
|
||||
self.c_major_mode = utils.LayoutEnum.from_tensor(mC)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create layouts for shared memory for A and B:
|
||||
# - sA/sB is m/n-major to vectorized copies from shared
|
||||
# memory to registers. This is because the MMA layouts
|
||||
# for sA/sB are also m/n-major
|
||||
# - When gA/gB is k-major, pad 4 elements to reduce bank conflicts
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
padding_a = 4 if self.a_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
|
||||
padding_b = 4 if self.b_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
|
||||
sA_layout = cute.make_layout(
|
||||
(self._bM, self._bK, self._num_stages),
|
||||
stride=(1, (self._bM + padding_a), self._bK * (self._bM + padding_a)),
|
||||
)
|
||||
sB_layout = cute.make_layout(
|
||||
(self._bN, self._bK, self._num_stages),
|
||||
stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
|
||||
)
|
||||
|
||||
smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
|
||||
mB.element_type, sB_layout
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create copy layouts that will be used for asynchronous
|
||||
# global memory -> shared memory copies:
|
||||
# - The majorness of tA/tB follows the majorness of gA/gB
|
||||
# - For k-major, these layouts will copy values one-by-one from
|
||||
# from global memory, without vectorizing
|
||||
# - For m/n-major, it will vectorize to a 128bit copy for faster
|
||||
# data transfer between global and shared memory, as long
|
||||
# as the alignment of the tensor allows it. Otherwise, it
|
||||
# defaults to a non-vectorized copy
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
tA = cute.make_layout(
|
||||
(self._num_threads // self._bK, self._bK), stride=(self._bK, 1)
|
||||
)
|
||||
tB = cute.make_layout(
|
||||
(self._num_threads // self._bK, self._bK), stride=(self._bK, 1)
|
||||
)
|
||||
vA = cute.make_layout((1, 1))
|
||||
vB = cute.make_layout((1, 1))
|
||||
atom_async_copy_A = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mA.element_type.width,
|
||||
)
|
||||
atom_async_copy_B = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mB.element_type.width,
|
||||
)
|
||||
|
||||
if self.a_major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1
|
||||
atom_async_copy_A = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mA.element_type.width * num_vectorized,
|
||||
)
|
||||
major_mode_size = self._bM // num_vectorized
|
||||
tA = cute.make_layout(
|
||||
(major_mode_size, self._num_threads // major_mode_size),
|
||||
stride=(1, major_mode_size),
|
||||
)
|
||||
vA = cute.make_layout((num_vectorized, 1))
|
||||
|
||||
if self.b_major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1
|
||||
atom_async_copy_B = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mB.element_type.width * num_vectorized,
|
||||
)
|
||||
major_mode_size = self._bN // num_vectorized
|
||||
tB = cute.make_layout(
|
||||
(major_mode_size, self._num_threads // major_mode_size),
|
||||
stride=(1, major_mode_size),
|
||||
)
|
||||
vB = cute.make_layout((num_vectorized, 1))
|
||||
|
||||
tiled_copy_A = cute.make_tiled_copy_tv(atom_async_copy_A, tA, vA)
|
||||
tiled_copy_B = cute.make_tiled_copy_tv(atom_async_copy_B, tB, vB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create layouts for GEMM:
|
||||
# We tile an MMA atom across a tensor. `atoms_layout` is the layout
|
||||
# of atoms in the tiled MMA. (Because we use an `MmaUniversalOp`,
|
||||
# which has a trivial 1x1x1 MMA trait, `atoms_layout` is also
|
||||
# simply the thread layout for C.) `permutation_tiler` reorders the
|
||||
# elements of the tensor that the tiled MMA is applied to.
|
||||
# Different combinations of `atoms_layout` and `permutation_tiler`
|
||||
# values can create different MMA thread-value patterns.
|
||||
#
|
||||
# Here, the MMA layout is set so that each thread copies four
|
||||
# consecutive elements from shared memory to registers.
|
||||
# `permutation_tiler_M/N` maps the elements handled by each thread
|
||||
# to the permuted element in the tensor.
|
||||
# For increasing indices in the tensor, the thread ID that reads it is:
|
||||
# - (without permutation) ==>
|
||||
# 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 ......
|
||||
# - (with permutation) ==>
|
||||
# 0 0 0 0 1 1 1 1 2 2 2 2 ... 15 15 15 15 0 0 0 0 1 1 1 1 ......
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
atoms_layout = cute.make_layout(
|
||||
(self._num_threads // 16, 16, 1), stride=(16, 1, 0)
|
||||
)
|
||||
if self.c_major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
atoms_layout = cute.make_layout(
|
||||
(16, self._num_threads // 16, 1), stride=(1, 16, 0)
|
||||
)
|
||||
op = cute.nvgpu.MmaUniversalOp(cutlass.Float32)
|
||||
permutation_tiler_M = cute.make_layout(
|
||||
(atoms_layout.shape[0], 4), stride=(4, 1)
|
||||
)
|
||||
permutation_tiler_N = cute.make_layout(
|
||||
(atoms_layout.shape[1], 4), stride=(4, 1)
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
op,
|
||||
atoms_layout,
|
||||
permutation_mnk=(permutation_tiler_M, permutation_tiler_N, None),
|
||||
)
|
||||
|
||||
# grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, 1)
|
||||
grid_dim = *cute.ceil_div(mC.shape, (self._bM, self._bN)), 1
|
||||
|
||||
self.kernel(
|
||||
mA,
|
||||
mB,
|
||||
mC,
|
||||
sA_layout,
|
||||
sB_layout,
|
||||
tiled_copy_A,
|
||||
tiled_copy_B,
|
||||
tiled_mma,
|
||||
epilogue_op,
|
||||
).launch(
|
||||
grid=grid_dim,
|
||||
block=[cute.size(atoms_layout), 1, 1],
|
||||
smem=smem_size,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
sA_layout: cute.Layout,
|
||||
sB_layout: cute.Layout,
|
||||
tiled_copy_A: cute.TiledCopy,
|
||||
tiled_copy_B: cute.TiledCopy,
|
||||
tiled_mma: cute.TiledMma,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
# Thread and block indices
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
tiler_coord = (bidx, bidy, None)
|
||||
thr_mma = tiled_mma.get_slice(tidx)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread block.
|
||||
# gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
gA = cute.local_tile(
|
||||
mA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
|
||||
)
|
||||
gB = cute.local_tile(
|
||||
mB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
|
||||
)
|
||||
gC = cute.local_tile(
|
||||
mC, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, 1, None)
|
||||
)
|
||||
|
||||
# Move the pointer of gA/gB in the `-k`` direction, making the first
|
||||
# tile (instead of the last one) irregular in shape when k is irregular.
|
||||
# We first handle the irregular tile to avoid checking for this
|
||||
# condition within the mainloop.
|
||||
residue_k = mA.shape[1] - cutlass.Int32(self._bK) * gA.shape[2]
|
||||
gA = cute.domain_offset((0, residue_k, 0), gA)
|
||||
gB = cute.domain_offset((0, residue_k, 0), gB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread.
|
||||
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
|
||||
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
|
||||
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create shared memory buffer
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
|
||||
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
tAgA = thr_copy_A.partition_S(gA)
|
||||
tAsA = thr_copy_A.partition_D(sA)
|
||||
tBgB = thr_copy_B.partition_S(gB)
|
||||
tBsB = thr_copy_B.partition_D(sB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when the problem shape
|
||||
# isn't a multiple of the tile shape. If tApA/B[i] is 0, then do not
|
||||
# do the copy atom associated with index i.
|
||||
# cA: (BLK_M, BLK_K) => (blk_m, blk_k)
|
||||
# cB: (BLK_N, BLK_K) => (blk_n, blk_k)
|
||||
# tAcA: (CPY, CPY_M, CPY_K) => (blk_m, blk_k)
|
||||
# tBcB: (CPY, CPY_N, CPY_K) => (blk_n, blk_k)
|
||||
# tApA: (rest_v, CPY_M, CPY_K), stride=(..., ..., 0)
|
||||
# tBpB: (rest_v, CPY_N, CPY_K), stride=(..., ..., 0)
|
||||
# CPY = (atom_v, rest_v)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Construct identity layout for sA and sB, used for predication
|
||||
mcA = cute.make_identity_tensor(mA.shape)
|
||||
mcB = cute.make_identity_tensor(mB.shape)
|
||||
cA = cute.local_tile(
|
||||
mcA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
|
||||
)
|
||||
cB = cute.local_tile(
|
||||
mcB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
|
||||
)
|
||||
cA = cute.domain_offset((0, residue_k, 0), cA)
|
||||
cB = cute.domain_offset((0, residue_k, 0), cB)
|
||||
# Repeat the partitioning with identity layouts
|
||||
tAcA = thr_copy_A.partition_S(cA)
|
||||
tBcB = thr_copy_B.partition_S(cB)
|
||||
# Allocate predicate tensors for m and n
|
||||
tApA = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAsA.shape[0][1],
|
||||
cute.size(tAsA, mode=[1]),
|
||||
cute.size(tAsA, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tAsA, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tBsB, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Allocate predicate tensors for m, n and k for residue k-tile
|
||||
tApA_residue_k = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAsA.shape[0][1],
|
||||
cute.size(tAsA, mode=[1]),
|
||||
cute.size(tAsA, mode=[2]),
|
||||
),
|
||||
stride=(
|
||||
cute.size(tAsA, mode=[1]) * cute.size(tAsA, mode=[2]),
|
||||
cute.size(tAsA, mode=[2]),
|
||||
1,
|
||||
),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB_residue_k = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(
|
||||
cute.size(tBsB, mode=[1]) * cute.size(tBsB, mode=[2]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
1,
|
||||
),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Set predicates for m/n bounds for mainloop
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cute.elem_less(
|
||||
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
|
||||
)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cute.elem_less(
|
||||
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
|
||||
)
|
||||
|
||||
# Set predicates for m/n/k bounds for residue k tile
|
||||
for rest_v in range(tApA_residue_k.shape[0]):
|
||||
for m in range(tApA_residue_k.shape[1]):
|
||||
for k in range(tApA_residue_k.shape[2]):
|
||||
coord_A = tAcA[(0, rest_v), m, k, 0]
|
||||
tApA_residue_k[rest_v, m, k] = cute.elem_less(
|
||||
(coord_A[0], cutlass.Int32(-1)), (mA.shape[0], coord_A[1])
|
||||
)
|
||||
for rest_v in range(tBpB_residue_k.shape[0]):
|
||||
for n in range(tBpB_residue_k.shape[1]):
|
||||
for k in range(tBpB_residue_k.shape[2]):
|
||||
coord_B = tBcB[(0, rest_v), n, k, 0]
|
||||
tBpB_residue_k[rest_v, n, k] = cute.elem_less(
|
||||
(coord_B[0], cutlass.Int32(-1)), (mB.shape[0], coord_B[1])
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Prefetch Prologue
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Start async loads for 0th k-tile, where we take care of the k-residue
|
||||
k_pipe_max = cute.size(tAsA, mode=[3])
|
||||
k_tile_count = cute.size(tAgA, mode=[3])
|
||||
gmem_pipe_read = cutlass.Int32(0)
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, gmem_pipe_read],
|
||||
tAsA[None, None, None, 0],
|
||||
pred=tApA_residue_k,
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, gmem_pipe_read],
|
||||
tBsB[None, None, None, 0],
|
||||
pred=tBpB_residue_k,
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
gmem_pipe_read = (
|
||||
gmem_pipe_read + 1
|
||||
if gmem_pipe_read + 1 < k_tile_count
|
||||
else cutlass.Int32(0)
|
||||
)
|
||||
# Start async loads for 1st k-tile onwards, no k-residue handling needed
|
||||
for k_tile in range(1, k_pipe_max - 1):
|
||||
if k_tile < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, gmem_pipe_read],
|
||||
tAsA[None, None, None, k_tile],
|
||||
pred=tApA,
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, gmem_pipe_read],
|
||||
tBsB[None, None, None, k_tile],
|
||||
pred=tBpB,
|
||||
)
|
||||
|
||||
gmem_pipe_read = (
|
||||
gmem_pipe_read + 1
|
||||
if gmem_pipe_read + 1 < k_tile_count
|
||||
else cutlass.Int32(0)
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# all tiles have been copied from global memory, so clear the
|
||||
# predicate tensor
|
||||
if k_tile_count < k_pipe_max:
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cutlass.Boolean(0)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cutlass.Boolean(0)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Define A/B partitioning and C accumulators.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
tCsA = thr_mma.partition_A(sA)
|
||||
tCsB = thr_mma.partition_B(sB)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
|
||||
tCrC = tiled_mma.make_fragment_C(tCgC)
|
||||
# Clear the accumulator
|
||||
tCrC.fill(0.0)
|
||||
|
||||
# Current pipe index in smem to read from / write to
|
||||
smem_pipe_read = cutlass.Int32(0)
|
||||
smem_pipe_write = cutlass.Int32(k_pipe_max - 1)
|
||||
|
||||
tCsA_p = tCsA[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB[None, None, None, smem_pipe_read]
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# PREFETCH register pipeline
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
k_block_max = cute.size(tCrA, mode=[2])
|
||||
|
||||
if k_block_max > 1:
|
||||
# Wait until our first prefetched tile is loaded in
|
||||
cute.arch.cp_async_wait_group(k_pipe_max - 2)
|
||||
cute.arch.barrier()
|
||||
# Prefetch the first rmem from the first k-tile
|
||||
cute.autovec_copy(tCsA_p[None, None, 0], tCrA[None, None, 0])
|
||||
cute.autovec_copy(tCsB_p[None, None, 0], tCrB[None, None, 0])
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Mainloop
|
||||
# 1. Shared memory pipeline (gmem -> smem):
|
||||
# The default smem pipeline depth is 3, meaning that for shared
|
||||
# memory buffers, we allocate three times the size described by the
|
||||
# CTA tiler. We prefetch 2 of these buffers before entering the main
|
||||
# loop. Considering only the transfer from global memory to shared
|
||||
# memory, the general structure of the mainloop is:
|
||||
# (1) copy k-tile from gmem to smem;
|
||||
# (2) perform gemm computation on k-tile;
|
||||
# (3) wait for the next copy to finish.
|
||||
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
|
||||
# waits for the number of unfinished 'copy' to be <= 1. The advantage
|
||||
# of this approach is that it allows for simultaneous production
|
||||
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
|
||||
# A common misconception is to prefetch N buffers and rewrite
|
||||
# the pipeline logic to wait on N-1 pending copies. The disadvantage
|
||||
# of this approach is that it requires fully consuming a buffer in
|
||||
# order to open an empty buffer for the next copy.
|
||||
# 2. Register pipeline (smem -> register):
|
||||
# Similarly, the register pipeline produces i+1, consumes i, and
|
||||
# produces i+2... Notably, i and i+1 do not use the same register,
|
||||
# eliminating dependencies on the same register for better parallelism.
|
||||
# 3. Combining the smem and register pipelines results in the mainloop.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
for _ in cutlass.range_dynamic(k_tile_count, unroll=1):
|
||||
for k_block in range(k_block_max):
|
||||
if k_block == k_block_max - 1:
|
||||
tCsA_p = tCsA[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB[None, None, None, smem_pipe_read]
|
||||
cute.arch.cp_async_wait_group(k_pipe_max - 2)
|
||||
cute.arch.barrier()
|
||||
|
||||
# Load A, B from shared memory to registers for k_block + 1
|
||||
k_block_next = (k_block + 1) % k_block_max # static
|
||||
cute.autovec_copy(
|
||||
tCsA_p[None, None, k_block_next],
|
||||
tCrA[None, None, k_block_next],
|
||||
)
|
||||
cute.autovec_copy(
|
||||
tCsB_p[None, None, k_block_next],
|
||||
tCrB[None, None, k_block_next],
|
||||
)
|
||||
|
||||
# Fetch next A: To better interleave global memory access and
|
||||
# compute instructions, we intentionally use the sequence:
|
||||
# copy A, perform GEMM, then copy B.
|
||||
if k_block == 0:
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, gmem_pipe_read],
|
||||
tAsA[None, None, None, smem_pipe_write],
|
||||
# Use predicates because the m-mode may be irregular
|
||||
pred=tApA,
|
||||
)
|
||||
|
||||
# Thread-level register gemm for k_block
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCrC,
|
||||
tCrA[None, None, k_block],
|
||||
tCrB[None, None, k_block],
|
||||
tCrC,
|
||||
)
|
||||
|
||||
# Fetch next B and update smem pipeline read/write
|
||||
if k_block == 0:
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, gmem_pipe_read],
|
||||
tBsB[None, None, None, smem_pipe_write],
|
||||
# Use predicates because the n-mode may be irregular
|
||||
pred=tBpB,
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
smem_pipe_write = smem_pipe_read
|
||||
smem_pipe_read = smem_pipe_read + 1
|
||||
if smem_pipe_read == k_pipe_max:
|
||||
smem_pipe_read = cutlass.Int32(0)
|
||||
# After copying all tiles, we avoid clearing the predicate
|
||||
# tensor in the `mainloop` to prevent increasing its
|
||||
# instruction count. Instead, we continue copying the
|
||||
# first tile, though it won't be used. The 0-th tile is not
|
||||
# copied due to its irregular shape, which could lead to
|
||||
# illegal memory accesses.
|
||||
gmem_pipe_read = (
|
||||
gmem_pipe_read + 1
|
||||
if gmem_pipe_read + 1 < k_tile_count
|
||||
else cutlass.Int32(1)
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Epilogue
|
||||
# Applies the epilogue operation to the accumulated results and copies
|
||||
# them without vectorization.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
cute.arch.cp_async_wait_group(0)
|
||||
cute.arch.barrier()
|
||||
tCrC.store(epilogue_op(tCrC.load()))
|
||||
|
||||
# predicate
|
||||
cC = cute.make_identity_tensor(gC.shape)
|
||||
tCpC = thr_mma.partition_C(cC)
|
||||
predC = cute.make_fragment(tCrC.layout, cutlass.Boolean)
|
||||
residue_m = mC.shape[0] - cutlass.Int32(self._bM) * bidx
|
||||
residue_n = mC.shape[1] - cutlass.Int32(self._bN) * bidy
|
||||
for i in range(cute.size(tCrC.shape)):
|
||||
predC[i] = cute.elem_less(tCpC[i], (residue_m, residue_n))
|
||||
numIterM = cute.size(tCrC, mode=[1])
|
||||
numIterN = cute.size(tCrC, mode=[2])
|
||||
atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type)
|
||||
cute.copy(atom, tCrC, tCgC, pred=predC)
|
||||
return
|
||||
|
||||
|
||||
def main(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
problem_shape: Tuple[int, int, int],
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
):
|
||||
torch.manual_seed(1024)
|
||||
M, N, K = problem_shape
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
|
||||
# is_mode0_major: (mode1, mode0) -> (mode0, mode1)
|
||||
# else: (mode0, mode1) -> (mode0, mode1)
|
||||
shape = (mode1, mode0) if is_mode0_major else (mode0, mode1)
|
||||
permute_order = (1, 0) if is_mode0_major else (0, 1)
|
||||
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-5, 5)
|
||||
.to(dtype=dtype)
|
||||
.permute(permute_order)
|
||||
.cuda()
|
||||
)
|
||||
|
||||
a = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
|
||||
b = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
|
||||
c = create_and_permute_tensor(M, N, c_major == "m", torch.float32)
|
||||
|
||||
divisibility_a = a.shape[1] if a_major == "k" else a.shape[0]
|
||||
divisibility_b = b.shape[1] if b_major == "k" else b.shape[0]
|
||||
divisibility_c = c.shape[1] if c_major == "n" else c.shape[0]
|
||||
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
divisibility=divisibility_a,
|
||||
)
|
||||
)
|
||||
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
divisibility=divisibility_b,
|
||||
)
|
||||
)
|
||||
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
divisibility=divisibility_c,
|
||||
)
|
||||
)
|
||||
|
||||
sgemm = SGemm()
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor)
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
# Create CUDA events for timing
|
||||
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Use the current stream for CUDA events instead of the default stream
|
||||
# Record start event
|
||||
cuda.cuEventRecord(start_event, current_stream)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Record end event
|
||||
cuda.cuEventRecord(end_event, current_stream)
|
||||
cuda.cuEventSynchronize(end_event)
|
||||
|
||||
# Calculate elapsed time
|
||||
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {elapsed_time / iterations:.4f} ms")
|
||||
|
||||
# Destroy events
|
||||
cuda.cuEventDestroy(start_event)
|
||||
cuda.cuEventDestroy(end_event)
|
||||
|
||||
if not skip_ref_check:
|
||||
print("Verifying results...")
|
||||
ref = torch.einsum("mk,nk->mn", a, b)
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
||||
try:
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--mnk", type=parse_comma_separated_ints, default=(256, 256, 64)
|
||||
)
|
||||
parser.add_argument("--a_major", choices=["k", "m"], default="k")
|
||||
parser.add_argument("--b_major", choices=["k", "n"], default="k")
|
||||
parser.add_argument("--c_major", choices=["n", "m"], default="n")
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running SIMT GEMM example:")
|
||||
main(
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.mnk,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
)
|
||||
print("PASS")
|
||||
200
examples/python/CuTeDSL/ampere/smem_allocator.py
Normal file
200
examples/python/CuTeDSL/ampere/smem_allocator.py
Normal file
@ -0,0 +1,200 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import cutlass.cute as cute
|
||||
import cutlass
|
||||
import torch
|
||||
import numpy as np
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL.
|
||||
|
||||
This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL.
|
||||
It shows various ways to allocate different data structures in shared memory:
|
||||
|
||||
1. Struct allocation with natural and strict alignment
|
||||
2. Raw memory block allocation with custom alignment
|
||||
3. Array allocation with automatic alignment
|
||||
4. Tensor allocation with layout specification
|
||||
|
||||
The example includes:
|
||||
- Shared storage struct with mixed alignment requirements
|
||||
- Memory allocation patterns for different data types
|
||||
- Tensor operations on allocated memory
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/smem_allocator.py
|
||||
|
||||
The example will allocate shared memory, perform tensor operations, and verify the results.
|
||||
"""
|
||||
|
||||
|
||||
@cute.struct
|
||||
class complex:
|
||||
real: cutlass.Float32
|
||||
imag: cutlass.Float32
|
||||
|
||||
|
||||
# SharedStorage size is 512, alignment is 128
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
# struct elements with natural alignment
|
||||
a: cute.struct.MemRange[cutlass.Float32, 32] # array
|
||||
b: cutlass.Int64 # saclar
|
||||
c: complex # nested struct
|
||||
# struct elements with strict alignment
|
||||
x: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Float32, 32],
|
||||
128,
|
||||
]
|
||||
y: cute.struct.Align[cutlass.Int32, 8]
|
||||
z: cute.struct.Align[complex, 16]
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
const_a: cutlass.Constexpr,
|
||||
dst_a: cute.Tensor,
|
||||
const_b: cutlass.Constexpr,
|
||||
dst_b: cute.Tensor,
|
||||
const_c: cutlass.Constexpr,
|
||||
dst_c: cute.Tensor,
|
||||
):
|
||||
# Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) can be reserved for developer to utilize
|
||||
# Note: alignment of inital allocator base ptr is 1024
|
||||
allocator = cutlass.utils.SmemAllocator()
|
||||
# base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory)
|
||||
|
||||
# -- Allocate a struct --
|
||||
# Note: when specified alignment, max(alignment, alignof(struct)) will be applied
|
||||
# reserves the section of struct in smem, elements in the struct can be accessed by ptr
|
||||
struct_in_smem = allocator.allocate(SharedStorage)
|
||||
# base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct)
|
||||
|
||||
# -- Allocate a block of memory --
|
||||
# reserves a section of 64 bytes in smem, align to 128 bytes, returns the section base ptr
|
||||
section_in_smem = allocator.allocate(64, byte_alignment=128)
|
||||
# base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section)
|
||||
|
||||
# -- Allocate an array --
|
||||
# reserves an int64 array of size 14 in smem, returns the array base ptr
|
||||
array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14)
|
||||
# base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array)
|
||||
|
||||
# -- Allocate a tensor --
|
||||
# Note: use cute.ComposedLayout or cute.Layout to specify layout of tensor
|
||||
# Note: iterator swizzle with swizzle layout is currently not supported
|
||||
layout = cute.make_layout((16, 2))
|
||||
tensor_in_smem = allocator.allocate_tensor(
|
||||
element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None
|
||||
)
|
||||
# base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor)
|
||||
|
||||
# ptr<f16, smem, align<1024>>
|
||||
# ptr<i64, smem, align<128>>
|
||||
# ptr<f32, smem, align<8>>
|
||||
print(struct_in_smem.a.data_ptr())
|
||||
print(struct_in_smem.b)
|
||||
print(struct_in_smem.c.real)
|
||||
# ptr<i8, smem, align<512>>
|
||||
print(section_in_smem)
|
||||
# ptr<i64, smem, align<64>>
|
||||
print(array_in_smem)
|
||||
# tensor<ptr<f16, smem, align<32>> o (16,4):(1,16)>
|
||||
print(tensor_in_smem)
|
||||
|
||||
# fill MemRange tensor in struct and copy to dst
|
||||
a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4)))
|
||||
a_tensor.fill(const_a)
|
||||
cute.printf("cute.struct.MemRange: {}", a_tensor)
|
||||
dst_a.store(a_tensor.load())
|
||||
|
||||
# convert block of smem to fill tensor and copy to dst
|
||||
layout = cute.make_layout((8, 2))
|
||||
sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32)
|
||||
sec_tensor = cute.make_tensor(sec_ptr, layout)
|
||||
sec_tensor.fill(const_b)
|
||||
cute.printf("block of memory: {}", sec_tensor)
|
||||
dst_b.store(sec_tensor.load())
|
||||
|
||||
# fill allocated tensor in smem and copy to dst
|
||||
tensor_in_smem.fill(const_c)
|
||||
cute.printf("tensor in smem: {}", tensor_in_smem)
|
||||
dst_c.store(tensor_in_smem.load())
|
||||
|
||||
|
||||
@cute.jit
|
||||
def run_allocation_kernel(
|
||||
const_a: cutlass.Constexpr,
|
||||
dst_a: cute.Tensor,
|
||||
const_b: cutlass.Constexpr,
|
||||
dst_b: cute.Tensor,
|
||||
const_c: cutlass.Constexpr,
|
||||
dst_c: cute.Tensor,
|
||||
):
|
||||
# additional size for the example, 64(section) + 112(array) + 128(tensor) < 384
|
||||
addtional_bytes = 384
|
||||
# Note: launch shared memory size is: SMEM_SIZE = 512 + 384 = 896 bytes
|
||||
kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch(
|
||||
grid=(1, 1, 1),
|
||||
block=(1, 1, 1),
|
||||
smem=SharedStorage.size_in_bytes() + addtional_bytes,
|
||||
)
|
||||
|
||||
|
||||
def veify_allocation_kernel(const_a, const_b, const_c):
|
||||
dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda")
|
||||
dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda")
|
||||
dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda")
|
||||
|
||||
run_allocation_kernel(
|
||||
const_a,
|
||||
from_dlpack(dst_a),
|
||||
const_b,
|
||||
from_dlpack(dst_b),
|
||||
const_c,
|
||||
from_dlpack(dst_c),
|
||||
)
|
||||
|
||||
np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0])
|
||||
np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0])
|
||||
np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# prepare cuda context
|
||||
cutlass.cuda.initialize_cuda_context()
|
||||
# An example for shared memory allocation
|
||||
const_a = 0.5
|
||||
const_b = 1.0
|
||||
const_c = 2.0
|
||||
veify_allocation_kernel(const_a, const_b, const_c)
|
||||
968
examples/python/CuTeDSL/ampere/tensorop_gemm.py
Normal file
968
examples/python/CuTeDSL/ampere/tensorop_gemm.py
Normal file
@ -0,0 +1,968 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
from typing import Tuple, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE DSL.
|
||||
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
||||
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
||||
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
||||
|
||||
This GEMM kernel supports the following features:
|
||||
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
|
||||
- Supports multi-stage pipeline to overlap computation and memory access
|
||||
- Implements shared memory buffering for epilogue to increase coalesed global memory access
|
||||
|
||||
This GEMM works as follows:
|
||||
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
|
||||
2. Perform matrix multiply-accumulate (MMA) operations.
|
||||
3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM).
|
||||
|
||||
The Ampere tensor core instruction used operates as follows:
|
||||
- Read matrix A from SMEM
|
||||
- Read matrix B from SMEM
|
||||
- Perform MMA operation and store the result in Accumulator(register)
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/tensorop_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \
|
||||
--ab_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major m --b_major n --c_major n
|
||||
|
||||
The above example command computes with M=8192, N=8192, K=8192,
|
||||
batch_count=1. The atom layout's shape is 2x2x1 and the input, mma
|
||||
accumulator, and output data type are set as fp16, fp32 and fp16,
|
||||
respectively.
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
ncu python examples/ampere/tensorop_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \
|
||||
--ab_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major m --b_major n --c_major n \
|
||||
--skip_ref_check --iterations 2
|
||||
|
||||
Constraints:
|
||||
* Supported input and output data types: fp16
|
||||
* Support accumulator data types: f32
|
||||
* Default tile shape is set to be 128x128x32
|
||||
* Atom layout's MNK shape is set so that tile shape can be divided by MMA
|
||||
instruction shape
|
||||
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
||||
i.e, number of elements is a multiple of 8
|
||||
"""
|
||||
|
||||
|
||||
class TensorOpGemm:
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
):
|
||||
self.ab_dtype = ab_dtype
|
||||
self.c_dtype = c_dtype
|
||||
self.acc_dtype = acc_dtype
|
||||
self.cta_tiler = (128, 128, 32)
|
||||
self.num_stages = 3
|
||||
self.atom_layout_mnk = atom_layout_mnk
|
||||
atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk
|
||||
self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32
|
||||
|
||||
self.bM, self.bN, self.bK = self.cta_tiler
|
||||
self.mma_inst_shape = (16, 8, 16)
|
||||
mmaM, mmaN, mmaK = self.mma_inst_shape
|
||||
|
||||
assert (
|
||||
self.bM % (atom_lay_M * mmaM) == 0
|
||||
), "bM must be divisible by MMA instruction"
|
||||
assert (
|
||||
self.bN % (atom_lay_N * mmaN) == 0
|
||||
), "bN must be divisible by MMA instruction"
|
||||
assert atom_lay_K == 1, "this example does not support atom layout K > 1"
|
||||
assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction"
|
||||
assert self.num_stages >= 3, "num_stages must be greater than or equal to 3"
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
# The grid divides the problems's M, N, and L dimensions by the
|
||||
# respective modes of the tile shape (bM, bN, 1). The K dimension is
|
||||
# handled within a block via a multistage process.
|
||||
|
||||
self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
|
||||
self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
|
||||
self.c_major_mode = utils.LayoutEnum.from_tensor(mC)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Shared memory layout:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Creates a layout with the size required for the provided tile
|
||||
# size and num stages (stages are used for K dimension) that is also
|
||||
# sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that
|
||||
# the atom for the shared memory -> register copy does not encounter
|
||||
# bank conflicts
|
||||
|
||||
# assume the input is 16B align
|
||||
ab_copy_bits = 128
|
||||
sA_layout = self._make_smem_layout_AB(
|
||||
mA.element_type,
|
||||
self.a_major_mode,
|
||||
ab_copy_bits,
|
||||
(self.cta_tiler[0], self.cta_tiler[2], self.num_stages),
|
||||
)
|
||||
sB_layout = self._make_smem_layout_AB(
|
||||
mB.element_type,
|
||||
self.b_major_mode,
|
||||
ab_copy_bits,
|
||||
(self.cta_tiler[1], self.cta_tiler[2], self.num_stages),
|
||||
)
|
||||
|
||||
# Creates a similar layout but without num_stages or layout atoms
|
||||
sC_layout = self._make_smem_layout_C(
|
||||
mC.element_type,
|
||||
self.c_major_mode,
|
||||
ab_copy_bits,
|
||||
(self.cta_tiler[0], self.cta_tiler[1]),
|
||||
)
|
||||
|
||||
# Shared memory allocated for operations with A, B will be
|
||||
# overwritten for operations on C. This is to improve performance
|
||||
# by reducing the size of shared memory requested by each block
|
||||
smem_size = max(
|
||||
cute.size_in_bytes(mC.element_type, sC_layout),
|
||||
cute.size_in_bytes(mA.element_type, sA_layout)
|
||||
+ cute.size_in_bytes(mB.element_type, sB_layout),
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tiled copy:
|
||||
# The majorness of tA/tB/tC follows the majorness of gA/gB/gC,
|
||||
# enabling merged accesses to global memory for faster data
|
||||
# transfer between global and shared memory.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Create a copy atom for a global to shared memory asynchronous copy
|
||||
atom_async_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(
|
||||
cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL
|
||||
),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=ab_copy_bits,
|
||||
)
|
||||
|
||||
# Create thread layouts for tiled copy from the copy atom where the
|
||||
# thread layout simply follows the leading dimension of the tensor
|
||||
tiled_copy_A = self._make_gmem_tiled_copy_AB(
|
||||
atom_async_copy, mA.element_type, self.a_major_mode, ab_copy_bits
|
||||
)
|
||||
tiled_copy_B = self._make_gmem_tiled_copy_AB(
|
||||
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
|
||||
)
|
||||
|
||||
# Creates a synchonous copy atom and thread layouts for the epilogue
|
||||
c_copy_bits = 128
|
||||
atom_sync_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
mC.element_type,
|
||||
num_bits_per_copy=c_copy_bits,
|
||||
)
|
||||
tiled_copy_C = self._make_gmem_tiled_copy_C(
|
||||
atom_sync_copy, mC.element_type, self.c_major_mode, c_copy_bits
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tiled MMA
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Creates a mma atom with 16x8x16 shape for MNK
|
||||
op = cute.nvgpu.warp.MmaF16BF16Op(
|
||||
self.ab_dtype, self.acc_dtype, self.mma_inst_shape
|
||||
)
|
||||
|
||||
permutation_mnk = (
|
||||
self.atom_layout_mnk[0] * self.mma_inst_shape[0],
|
||||
# if atom layout's N-mode is 1, to leverage the largest coalesced
|
||||
# shared memory -> register copy, set the tiled mma's N mode to 16
|
||||
self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2,
|
||||
self.atom_layout_mnk[2] * self.mma_inst_shape[2],
|
||||
)
|
||||
|
||||
# Created a tiled mma that tiles the atom according to specified layout.
|
||||
# For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice
|
||||
# across M and twice across N
|
||||
tC = cute.make_layout(self.atom_layout_mnk)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
op,
|
||||
tC,
|
||||
permutation_mnk=permutation_mnk,
|
||||
)
|
||||
|
||||
# grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l)
|
||||
grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
|
||||
|
||||
self.kernel(
|
||||
mA,
|
||||
mB,
|
||||
mC,
|
||||
sA_layout,
|
||||
sB_layout,
|
||||
sC_layout,
|
||||
tiled_copy_A,
|
||||
tiled_copy_B,
|
||||
tiled_copy_C,
|
||||
tiled_mma,
|
||||
epilogue_op,
|
||||
).launch(
|
||||
grid=grid_dim,
|
||||
block=[self.num_threads, 1, 1],
|
||||
smem=smem_size,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
sA_layout: cute.ComposedLayout,
|
||||
sB_layout: cute.ComposedLayout,
|
||||
sC_layout: cute.ComposedLayout,
|
||||
tiled_copy_A: cute.TiledCopy,
|
||||
tiled_copy_B: cute.TiledCopy,
|
||||
tiled_copy_C: cute.TiledCopy,
|
||||
tiled_mma: cute.TiledMma,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
# Thread index, block index
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
tiler_coord = (bidx, bidy, None)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread block.
|
||||
# gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
gA = cute.local_tile(
|
||||
mA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
gB = cute.local_tile(
|
||||
mB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
gC = cute.local_tile(
|
||||
mC[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, 1, None),
|
||||
)
|
||||
|
||||
# By default, if the tensor k mode does not divide into the tile k
|
||||
# size, then last tiles in the k dimension are irregular.
|
||||
# Instead, make the first tiles irregular when k is irregular.
|
||||
# This allows us to handle the irregular tile first to avoid
|
||||
# checking for this condition within the mainloop.
|
||||
|
||||
# residual_k is a negative number indicating the amount needed to
|
||||
# shift the pointer by in dimension k
|
||||
residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
|
||||
gA, mode=[2]
|
||||
)
|
||||
|
||||
# move the pointer of gA/gB in the `-k` direction
|
||||
gA = cute.domain_offset((0, residual_k, 0), gA)
|
||||
gB = cute.domain_offset((0, residual_k, 0), gB)
|
||||
# input is 16B aligned
|
||||
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
|
||||
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
|
||||
|
||||
# Construct identity layout for sA and sB (mirrors global tensors,
|
||||
# used for predication only)
|
||||
mcA = cute.make_identity_tensor(mA.layout.shape)
|
||||
mcB = cute.make_identity_tensor(mB.layout.shape)
|
||||
cA = cute.local_tile(
|
||||
mcA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
cB = cute.local_tile(
|
||||
mcB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
|
||||
cA = cute.domain_offset((0, residual_k, 0), cA)
|
||||
cB = cute.domain_offset((0, residual_k, 0), cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create shared memory buffers and get the appropriate fragments for this thread.
|
||||
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
|
||||
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
|
||||
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Shared memory buffer
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
|
||||
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
|
||||
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
|
||||
sC = cute.make_tensor(
|
||||
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
|
||||
)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
tAgA = thr_copy_A.partition_S(gA)
|
||||
tAsA = thr_copy_A.partition_D(sA)
|
||||
tBgB = thr_copy_B.partition_S(gB)
|
||||
tBsB = thr_copy_B.partition_D(sB)
|
||||
tCsC_epilogue = thr_copy_C.partition_S(sC)
|
||||
tCgC_epilogue = thr_copy_C.partition_D(gC)
|
||||
|
||||
# Repeat the partitioning with identity layouts
|
||||
tAcA = thr_copy_A.partition_S(cA)
|
||||
tBcB = thr_copy_B.partition_S(cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
||||
# of tile_shape
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# For predication over the tensors A (M/K), B (N/K), and (in the
|
||||
# epilogue) C (M/N), we will compute it in a fashion similar to an
|
||||
# outer product. The predication along one of the dimensions is
|
||||
# evaluated and stored in a predication tensor. Then, the
|
||||
# predication for the remaining dimension is handled later via an
|
||||
# if/else branch at the copy.
|
||||
# For A and B, predication booleans along M/N are stored in a
|
||||
# predication tensor and along K is handled via a if/else branch.
|
||||
|
||||
# Allocate predicate tensors for M and N. Predication is checked
|
||||
# at the granularity of a copy atom, so the predicate tensor does not
|
||||
# need separate booleans for individual elements within a copy
|
||||
# atom (for example, the elements of tAgA.shape[0][0].)
|
||||
tApA = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAgA.shape[0][1],
|
||||
cute.size(tAgA, mode=[1]),
|
||||
cute.size(tAgA, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tAgA, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tBsB, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Set predicates for M/N bounds
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cute.elem_less(
|
||||
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
|
||||
)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cute.elem_less(
|
||||
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Prefetch Prologue
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Clear the smem tiles to account for predicated off loads
|
||||
tAsA.fill(0)
|
||||
tBsB.fill(0)
|
||||
cute.arch.sync_threads()
|
||||
# Start async loads for the first k-tile. Here we take care of the k residue
|
||||
# via if/else check along the k dimension. Because we shifted the identity tensor
|
||||
# by the residue_k and because the identity tensor is a counting tensor, the
|
||||
# values of any identity tensor element that is poison is less than -1
|
||||
num_smem_stages = cute.size(tAsA, mode=[3])
|
||||
k_tile_count = cute.size(tAgA, mode=[3])
|
||||
k_tile_index = cutlass.Int32(0)
|
||||
|
||||
for k in range(tApA.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, k, k_tile_index],
|
||||
tAsA[None, None, k, 0],
|
||||
pred=tApA[None, None, k],
|
||||
)
|
||||
for k in range(tBpB.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, k, k_tile_index],
|
||||
tBsB[None, None, k, 0],
|
||||
pred=tBpB[None, None, k],
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# Start async loads for rest of the k-tiles
|
||||
for k_tile in range(1, num_smem_stages - 1):
|
||||
if k_tile == k_tile_count:
|
||||
tApA.fill(0)
|
||||
tBpB.fill(0)
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, k_tile_index],
|
||||
tAsA[None, None, None, k_tile],
|
||||
pred=tApA,
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, k_tile_index],
|
||||
tBsB[None, None, None, k_tile],
|
||||
pred=tBpB,
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tile MMA compute thread partitions and allocate accumulators
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
thr_mma = tiled_mma.get_slice(tidx)
|
||||
tCsA = thr_mma.partition_A(sA)
|
||||
tCsB = thr_mma.partition_B(sB)
|
||||
tCsC = thr_mma.partition_C(sC)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
|
||||
tCrC = tiled_mma.make_fragment_C(tCgC)
|
||||
# Clear the accumulator
|
||||
tCrC.fill(0.0)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Copy Atom A/B retiling
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Create the copy atoms for the copy from shared memory to register
|
||||
atom_copy_s2r_A = cute.make_copy_atom(
|
||||
cute.nvgpu.warp.LdMatrix8x8x16bOp(
|
||||
self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
|
||||
),
|
||||
mA.element_type,
|
||||
)
|
||||
atom_copy_s2r_B = cute.make_copy_atom(
|
||||
cute.nvgpu.warp.LdMatrix8x8x16bOp(
|
||||
self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
|
||||
),
|
||||
mB.element_type,
|
||||
)
|
||||
|
||||
# Creates the tiled copy so that it matches the thread-value layout
|
||||
# expected by the tiled mma
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy(
|
||||
atom_copy_s2r_A,
|
||||
layout_tv=tiled_mma.tv_layout_A_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy(
|
||||
atom_copy_s2r_B,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
|
||||
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
|
||||
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
|
||||
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
|
||||
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
|
||||
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
|
||||
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
|
||||
|
||||
# Current pipe index in smem to read from / write to
|
||||
smem_pipe_read = 0
|
||||
smem_pipe_write = num_smem_stages - 1
|
||||
|
||||
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# PREFETCH register pipeline
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
num_k_block = cute.size(tCrA, mode=[2])
|
||||
if num_k_block > 1:
|
||||
# Wait until our first prefetched tile is loaded in
|
||||
cute.arch.cp_async_wait_group(num_smem_stages - 2)
|
||||
cute.arch.sync_threads()
|
||||
# Prefetch the first k-block rmem from the first k-tile
|
||||
cute.copy(
|
||||
tiled_copy_s2r_A,
|
||||
tCsA_p[None, None, 0],
|
||||
tCrA_copy_view[None, None, 0],
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_s2r_B,
|
||||
tCsB_p[None, None, 0],
|
||||
tCrB_copy_view[None, None, 0],
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Mainloop
|
||||
# 1. Shared memory pipeline (gmem -> smem):
|
||||
# The default smem pipeline depth is 3, meaning that for shared
|
||||
# memory buffers, we allocate three times the size described by the
|
||||
# CTA tiler. We prefetch 2 of these buffers before entering the main
|
||||
# loop. Considering only the transfer from global memory to shared
|
||||
# memory, the general structure of the mainloop is:
|
||||
# (1) copy k-tile from gmem to smem;
|
||||
# (2) perform gemm computation on k-tile;
|
||||
# (3) wait for the next copy to finish.
|
||||
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
|
||||
# waits for the number of unfinished 'copy' to be <= 1. The advantage
|
||||
# of this approach is that it allows for simultaneous production
|
||||
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
|
||||
# A common misconception is to prefetch N buffers and rewrite
|
||||
# the pipeline logic to wait on N-1 pending copies. The disadvantage
|
||||
# of this approach is that it requires fully consuming a buffer in
|
||||
# order to open an empty buffer for the next copy.
|
||||
# 2. Register pipeline (smem -> register):
|
||||
# Similarly, the register pipeline produces i+1, consumes i, and
|
||||
# produces i+2... Notably, i and i+1 do not use the same register,
|
||||
# eliminating dependencies on the same register for better parallelism.
|
||||
# 3. Combining the smem and register pipelines results in the mainloop.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
|
||||
for k_block in range(num_k_block):
|
||||
if k_block == num_k_block - 1:
|
||||
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
|
||||
cute.arch.cp_async_wait_group(num_smem_stages - 2)
|
||||
cute.arch.sync_threads()
|
||||
|
||||
# Load A, B from shared memory to registers for k_block + 1
|
||||
k_block_next = (k_block + 1) % num_k_block # static
|
||||
cute.copy(
|
||||
tiled_copy_s2r_A,
|
||||
tCsA_p[None, None, k_block_next],
|
||||
tCrA_copy_view[None, None, k_block_next],
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_s2r_B,
|
||||
tCsB_p[None, None, k_block_next],
|
||||
tCrB_copy_view[None, None, k_block_next],
|
||||
)
|
||||
|
||||
# Fetch next A: To better interleave global memory access and compute
|
||||
# instructions, we intentionally use the sequence: copy A, perform GEMM,
|
||||
# then copy B.
|
||||
if k_block == 0:
|
||||
if k_tile + num_smem_stages - 1 < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, k_tile_index],
|
||||
tAsA[None, None, None, smem_pipe_write],
|
||||
pred=tApA,
|
||||
)
|
||||
|
||||
# Thread-level register gemm for k_block
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCrC,
|
||||
tCrA[None, None, k_block],
|
||||
tCrB[None, None, k_block],
|
||||
tCrC,
|
||||
)
|
||||
|
||||
# Fetch next B and update smem pipeline read/write
|
||||
if k_block == 0:
|
||||
if k_tile + num_smem_stages - 1 < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, k_tile_index],
|
||||
tBsB[None, None, None, smem_pipe_write],
|
||||
pred=tBpB,
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
smem_pipe_write = smem_pipe_read
|
||||
smem_pipe_read = smem_pipe_read + 1
|
||||
if smem_pipe_read == num_smem_stages:
|
||||
smem_pipe_read = 0
|
||||
|
||||
# Sync before epilogue
|
||||
cute.arch.cp_async_wait_group(0)
|
||||
cute.arch.sync_threads()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Epilogue with fusion
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
tCrD = cute.make_fragment_like(tCrC, self.c_dtype)
|
||||
tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype)
|
||||
|
||||
# Copy results of D back to shared memory
|
||||
cute.autovec_copy(tCrD, tCsC)
|
||||
|
||||
# Create counting tensor for C
|
||||
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
|
||||
mcC = cute.make_identity_tensor(
|
||||
(
|
||||
cute.size(ceilM) * self.cta_tiler[0],
|
||||
cute.size(ceilN) * self.cta_tiler[1],
|
||||
1,
|
||||
)
|
||||
)
|
||||
cC = cute.local_tile(
|
||||
mcC[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, 1, None),
|
||||
)
|
||||
tCcC = thr_copy_C.partition_S(cC)
|
||||
|
||||
tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue)
|
||||
# Wait for all writes to shared memory to finish before starting copies
|
||||
# using the new layouts
|
||||
cute.arch.sync_threads()
|
||||
cute.autovec_copy(tCsC_epilogue, tCrC_epilogue)
|
||||
|
||||
# Create predication tensor for m
|
||||
tCpC = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tCgC_epilogue.shape[0][1],
|
||||
cute.size(tCgC_epilogue, mode=[1]),
|
||||
cute.size(tCgC_epilogue, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
for rest_v in range(tCpC.shape[0]):
|
||||
for m in range(tCpC.shape[1]):
|
||||
tCpC[rest_v, m, 0] = cute.elem_less(
|
||||
tCcC[(0, rest_v), m, 0][0], mC.shape[0]
|
||||
)
|
||||
|
||||
# Copy to global memory using better vectorization
|
||||
for rest_v in range(tCpC.shape[0]):
|
||||
for n in range(tCpC.shape[2]):
|
||||
if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]):
|
||||
cute.copy(
|
||||
tiled_copy_C,
|
||||
tCrC_epilogue[None, None, n],
|
||||
tCgC_epilogue[None, None, n],
|
||||
pred=tCpC[None, None, n],
|
||||
)
|
||||
return
|
||||
|
||||
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
|
||||
major_mode_size = (
|
||||
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
|
||||
)
|
||||
major_mode_size = 64 if major_mode_size >= 64 else major_mode_size
|
||||
|
||||
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
|
||||
swizzle_bits = min(swizzle_bits, 3)
|
||||
|
||||
layout_atom_outer = (
|
||||
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
|
||||
)
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(swizzle_bits, 3, 3),
|
||||
0,
|
||||
layout_atom_outer,
|
||||
)
|
||||
layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2))
|
||||
return layout
|
||||
|
||||
def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler):
|
||||
major_mode_size = (
|
||||
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
|
||||
)
|
||||
|
||||
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
|
||||
swizzle_bits = min(swizzle_bits, 3)
|
||||
|
||||
layout_atom_outer = (
|
||||
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
|
||||
)
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(swizzle_bits, 3, 4),
|
||||
0,
|
||||
layout_atom_outer,
|
||||
)
|
||||
|
||||
# Due to the thread layout of the mma, remove swizzle in C to
|
||||
# prevent shared memory fragments owned by an single thread from
|
||||
# holding swizzles
|
||||
if major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(0, 3, 4), 0, layout_atom_outer
|
||||
)
|
||||
layout = cute.tile_to_shape(
|
||||
layout_atom,
|
||||
smem_tiler,
|
||||
(0, 1),
|
||||
)
|
||||
return layout
|
||||
|
||||
def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits):
|
||||
copy_elems = copy_bits // dtype.width
|
||||
shape_dim_1 = cute.size(self.bK) // copy_elems
|
||||
# thread layout for copy
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
||||
)
|
||||
if major_mode != utils.LayoutEnum.ROW_MAJOR:
|
||||
shape_dim_0 = cute.size(self.bM) // copy_elems
|
||||
thread_layout = cute.make_layout(
|
||||
(shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
||||
)
|
||||
# Value layout for copy
|
||||
value_layout = (
|
||||
cute.make_layout((1, copy_elems))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
|
||||
|
||||
def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits):
|
||||
copy_elems = copy_bits // dtype.width
|
||||
shape_dim_1 = cute.size(self.bN) // copy_elems
|
||||
# thread layout for copy
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
||||
)
|
||||
if major_mode != utils.LayoutEnum.ROW_MAJOR:
|
||||
shape_dim_0 = cute.size(self.bM) // copy_elems
|
||||
thread_layout = cute.make_layout(
|
||||
(shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
||||
)
|
||||
value_layout = (
|
||||
cute.make_layout((1, copy_elems))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
|
||||
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
|
||||
|
||||
|
||||
def run_tensor_op_gemm(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
problem_shape: Tuple[int, int, int, int],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
):
|
||||
M, N, K, L = problem_shape
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
|
||||
# is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
|
||||
# else: (l, mode0, mode1) -> (mode0, mode1, l)
|
||||
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
|
||||
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
||||
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=dtype)
|
||||
.permute(permute_order)
|
||||
.cuda()
|
||||
)
|
||||
|
||||
a = create_and_permute_tensor(
|
||||
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
b = create_and_permute_tensor(
|
||||
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
|
||||
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
|
||||
|
||||
tensor_op_gemm = TensorOpGemm(
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
acc_dtype,
|
||||
atom_layout_mnk,
|
||||
)
|
||||
|
||||
# assume input is 16B aligned
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
|
||||
divisibility=(128 // c_dtype.width),
|
||||
)
|
||||
)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
if not skip_ref_check:
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
||||
try:
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of multistage block matmul with CuTe on GPU"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mnkl", type=parse_comma_separated_ints, default=(112, 136, 40, 1)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--atom_layout_mnk", type=parse_comma_separated_ints, default=(2, 2, 1)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ab_dtype",
|
||||
type=cutlass.dtype,
|
||||
choices=[cutlass.Float16],
|
||||
default=cutlass.Float16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--acc_dtype",
|
||||
type=cutlass.dtype,
|
||||
choices=[cutlass.Float32],
|
||||
default=cutlass.Float32,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--c_dtype",
|
||||
type=cutlass.dtype,
|
||||
choices=[cutlass.Float16],
|
||||
default=cutlass.Float16,
|
||||
)
|
||||
parser.add_argument("--a_major", choices=["k", "m"], default="m")
|
||||
parser.add_argument("--b_major", choices=["k", "n"], default="n")
|
||||
parser.add_argument("--c_major", choices=["n", "m"], default="n")
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running Ampere tensor core GEMM example:")
|
||||
run_tensor_op_gemm(
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.ab_dtype,
|
||||
args.c_dtype,
|
||||
args.acc_dtype,
|
||||
args.mnkl,
|
||||
args.atom_layout_mnk,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
)
|
||||
print("PASS")
|
||||
1922
examples/python/CuTeDSL/blackwell/dense_gemm.py
Normal file
1922
examples/python/CuTeDSL/blackwell/dense_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
2144
examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
Normal file
2144
examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
Normal file
File diff suppressed because it is too large
Load Diff
2984
examples/python/CuTeDSL/blackwell/fmha.py
Normal file
2984
examples/python/CuTeDSL/blackwell/fmha.py
Normal file
File diff suppressed because it is too large
Load Diff
2287
examples/python/CuTeDSL/blackwell/grouped_gemm.py
Normal file
2287
examples/python/CuTeDSL/blackwell/grouped_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
51
examples/python/CuTeDSL/cute/ffi/CMakeLists.txt
Normal file
51
examples/python/CuTeDSL/cute/ffi/CMakeLists.txt
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.15)
|
||||
project(tensor)
|
||||
|
||||
# Find Python
|
||||
find_package(Python COMPONENTS Interpreter Development REQUIRED)
|
||||
|
||||
# Get Python site-packages directory using Python
|
||||
execute_process(
|
||||
COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])"
|
||||
OUTPUT_VARIABLE Python_SITE_PACKAGES
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}")
|
||||
|
||||
# Add nanobind path to CMAKE_PREFIX_PATH
|
||||
list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake)
|
||||
|
||||
# Find nanobind
|
||||
find_package(nanobind REQUIRED)
|
||||
|
||||
# Add the module
|
||||
nanobind_add_module(tensor tensor.cpp)
|
||||
305
examples/python/CuTeDSL/cute/ffi/jit_argument.py
Normal file
305
examples/python/CuTeDSL/cute/ffi/jit_argument.py
Normal file
@ -0,0 +1,305 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
|
||||
|
||||
This example demonstrates a basic approach to building customized interfaces as C-structures between user code
|
||||
and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions
|
||||
and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions.
|
||||
|
||||
The C-structure is defined as:
|
||||
|
||||
.. code-block:: c
|
||||
|
||||
struct Tensor {
|
||||
void *ptr; // Pointer to tensor data
|
||||
int32_t shape[3]; // Tensor dimensions
|
||||
int32_t strides[3]; // Memory strides for each dimension
|
||||
};
|
||||
|
||||
The example defines Tensor and TensorValue classes that wrap C structs for view of a tensor with its data pointer,
|
||||
shape, and strides, enabling efficient data passing between different language boundaries.
|
||||
|
||||
.. note::
|
||||
Future development may include automated code generation flows.
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import llvm
|
||||
import cutlass._mlir.extras.types as T
|
||||
|
||||
|
||||
class ExampleTensorValue(ir.Value):
|
||||
"""A wrapper class for tensor values in MLIR.
|
||||
|
||||
This class extends ir.Value to provide convenient access to tensor data pointer,
|
||||
shape, and strides through MLIR operations.
|
||||
|
||||
:type: ir.Value
|
||||
"""
|
||||
|
||||
def __init__(self, v):
|
||||
"""Initialize a new TensorValue.
|
||||
|
||||
:param v: The underlying MLIR value to wrap
|
||||
:type v: ir.Value
|
||||
"""
|
||||
super().__init__(v)
|
||||
|
||||
@property
|
||||
def data_ptr(self, *, loc=None, ip=None):
|
||||
"""Get the data pointer from the tensor value.
|
||||
|
||||
Extracts the data pointer (first field) from the LLVM struct value.
|
||||
|
||||
:param loc: Optional location information for MLIR operations
|
||||
:type loc: Optional[ir.Location]
|
||||
:param ip: Optional insertion point for MLIR operations
|
||||
:type ip: Optional[ir.InsertionPoint]
|
||||
:return: An integer value representing the data pointer
|
||||
:rtype: ir.Value
|
||||
"""
|
||||
# Extract the data pointer from the LLVM struct value
|
||||
# The data pointer is the first field (index 0) in the struct
|
||||
|
||||
# Use llvm.extractvalue to get the pointer field from the struct
|
||||
ptr_val = llvm.extractvalue(
|
||||
llvm.PointerType.get(),
|
||||
self,
|
||||
[0], # Extract the first field (index 0)
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
return cute.make_ptr(cutlass.Float32, ptr_val)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Get the shape of the tensor.
|
||||
|
||||
Extracts the shape (second field) from the LLVM struct value.
|
||||
|
||||
:return: A tuple of integers representing the tensor dimensions
|
||||
:rtype: tuple[ir.Value, ...]
|
||||
"""
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
# Extract the shape field from the LLVM struct value
|
||||
# The shape is the second field (index 1) in the struct
|
||||
shape_val = llvm.extractvalue(
|
||||
llvm.StructType.get_literal([i32_type] * 3),
|
||||
self,
|
||||
[1], # Extract the second field (index 1)
|
||||
)
|
||||
|
||||
# Extract each dimension from the shape struct
|
||||
return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3))
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
"""Get the strides of the tensor.
|
||||
|
||||
Extracts the strides (third field) from the LLVM struct value.
|
||||
|
||||
:return: A tuple of integers representing the tensor strides
|
||||
:rtype: tuple[ir.Value, ...]
|
||||
"""
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
# Extract the strides field from the LLVM struct value
|
||||
# The strides are the third field (index 2) in the struct
|
||||
strides_val = llvm.extractvalue(
|
||||
llvm.StructType.get_literal([i32_type] * 3),
|
||||
self,
|
||||
[2], # Extract the third field (index 2)
|
||||
)
|
||||
|
||||
# Extract each dimension from the strides struct
|
||||
return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3))
|
||||
|
||||
|
||||
class ExampleTensor:
|
||||
"""A class representing a tensor with its data pointer, shape, and strides.
|
||||
|
||||
This class provides a Python interface to create and manipulate tensor structures
|
||||
that can be passed to CUTE JIT compiled functions.
|
||||
|
||||
:ivar _c_struct_p: The C struct pointer for the tensor
|
||||
:ivar _rank: The number of dimensions in the tensor
|
||||
"""
|
||||
|
||||
def __init__(self, c_struct_p, rank):
|
||||
"""Initialize a new Tensor.
|
||||
|
||||
:param c_struct_p: The C struct pointer for the tensor
|
||||
:type c_struct_p: int
|
||||
:param rank: The number of dimensions in the tensor
|
||||
:type rank: int
|
||||
"""
|
||||
self._c_struct_p = c_struct_p
|
||||
self._rank = rank
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
"""Get the MLIR types for this tensor.
|
||||
|
||||
Creates an LLVM structure type representing a C-structure with:
|
||||
|
||||
.. code-block:: c
|
||||
|
||||
struct Tensor {
|
||||
void *ptr;
|
||||
int32_t shape[3];
|
||||
int32_t strides[3];
|
||||
};
|
||||
|
||||
:return: A list containing the MLIR struct type
|
||||
:rtype: list[llvm.StructType]
|
||||
|
||||
Create an LLVM structure type that represents a C-structure like:
|
||||
"""
|
||||
|
||||
# Get the number of dimensions from the shape
|
||||
ndim = self._rank
|
||||
|
||||
# Create the pointer type (void*)
|
||||
ptr_type = llvm.PointerType.get()
|
||||
|
||||
# Create array types for shape and strides (int32_t[ndim])
|
||||
int32_type = ir.IntegerType.get_signless(32)
|
||||
shape_type = llvm.StructType.get_literal([int32_type] * ndim)
|
||||
strides_type = llvm.StructType.get_literal([int32_type] * ndim)
|
||||
|
||||
# Create the structure type
|
||||
struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type])
|
||||
|
||||
return [struct_type]
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
"""Create a new TensorValue from MLIR values.
|
||||
|
||||
:param values: A list of MLIR values
|
||||
:type values: list[ir.Value]
|
||||
:return: A new TensorValue instance
|
||||
:rtype: TensorValue
|
||||
"""
|
||||
return ExampleTensorValue(values[0])
|
||||
|
||||
def __c_pointers__(self):
|
||||
"""Get the C pointers for this tensor.
|
||||
|
||||
:return: A list containing the C struct pointer
|
||||
:rtype: list[int]
|
||||
"""
|
||||
return [self._c_struct_p]
|
||||
|
||||
|
||||
@cute.jit
|
||||
def foo(tensor):
|
||||
"""Example JIT function that prints tensor information.
|
||||
|
||||
:param tensor: A Tensor instance to print information about
|
||||
:type tensor: Tensor
|
||||
"""
|
||||
cute.printf("data_ptr: {}", tensor.data_ptr)
|
||||
cute.printf("shape: {}", tensor.shape)
|
||||
cute.printf("stride: {}", tensor.stride)
|
||||
|
||||
mA = cute.make_tensor(
|
||||
tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride)
|
||||
)
|
||||
cute.print_tensor(mA)
|
||||
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import shutil
|
||||
import tempfile
|
||||
import torch
|
||||
|
||||
|
||||
def run_test(tmpdir=None):
|
||||
# Skip cleanup if user provides tmpdir
|
||||
cleanup = tmpdir is None
|
||||
# Initialize temporary build directory
|
||||
tmpdir = tmpdir or tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
|
||||
subprocess.run(["cmake", "--build", tmpdir], check=True)
|
||||
|
||||
sys.path.append(tmpdir)
|
||||
|
||||
from tensor import make_tensor, pycapsule_get_pointer
|
||||
|
||||
# Mock test tensor and corresponding C structure for this example
|
||||
# In production, this may come from external library
|
||||
x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
|
||||
c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
|
||||
c_struct_p = pycapsule_get_pointer(c_struct)
|
||||
|
||||
# Initialize tensor wrapper and compile test function
|
||||
tensor = ExampleTensor(c_struct_p, len(x.shape))
|
||||
compiled_func = cute.compile(foo, tensor)
|
||||
|
||||
# Benchmark pointer access performance
|
||||
from time import time
|
||||
|
||||
start = time()
|
||||
# Measure performance of critical path pointer access
|
||||
# get C pointers is on critical path to call JIT compiled function
|
||||
for _ in range(1000):
|
||||
tensor.__c_pointers__()
|
||||
end = time()
|
||||
print(f"__c_pointers__: {(end - start) * 1000} us")
|
||||
|
||||
# Execute compiled function
|
||||
compiled_func(tensor)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
finally:
|
||||
if cleanup:
|
||||
# Clean up the temporary directory
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Set temporary directory for building C modules"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tmp-dir", type=str, help="Temporary directory path for building C modules"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_test(args.tmp_dir)
|
||||
82
examples/python/CuTeDSL/cute/ffi/tensor.cpp
Normal file
82
examples/python/CuTeDSL/cute/ffi/tensor.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
|
||||
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
|
||||
// 3. Neither the name of the copyright holder nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
// POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <cstdint>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
// Forward declaration of the MockTensor struct for testing only
|
||||
struct MockTensor {
|
||||
void *ptr;
|
||||
struct {
|
||||
int32_t shape[3];
|
||||
} shape;
|
||||
|
||||
struct {
|
||||
int32_t strides[3];
|
||||
} strides;
|
||||
};
|
||||
|
||||
NB_MODULE(tensor, m) {
|
||||
// create a tensor for testing
|
||||
m.def("make_tensor", [](int64_t ptr, std::vector<int32_t> shape,
|
||||
std::vector<int32_t> strides) {
|
||||
auto *tensor = new MockTensor();
|
||||
tensor->ptr = reinterpret_cast<void *>(ptr);
|
||||
|
||||
assert(shape.size() == 3 && "shape must have 3 elements");
|
||||
assert(strides.size() == 3 && "strides must have 3 elements");
|
||||
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
tensor->shape.shape[i] = shape[i];
|
||||
tensor->strides.strides[i] = strides[i];
|
||||
}
|
||||
|
||||
return nb::steal(PyCapsule_New(tensor, "tensor", [](PyObject *capsule) {
|
||||
auto n = PyCapsule_GetName(capsule);
|
||||
if (void *p = PyCapsule_GetPointer(capsule, n)) {
|
||||
delete reinterpret_cast<MockTensor *>(p);
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
m.def(
|
||||
"pycapsule_get_pointer",
|
||||
[](nb::object &capsule) {
|
||||
void *ptr = PyCapsule_GetPointer(capsule.ptr(), "tensor");
|
||||
if (!ptr) {
|
||||
throw std::runtime_error("Invalid tensor capsule");
|
||||
}
|
||||
return reinterpret_cast<uintptr_t>(ptr);
|
||||
},
|
||||
"Get pointer from PyCapsule");
|
||||
}
|
||||
1486
examples/python/CuTeDSL/hopper/dense_gemm.py
Normal file
1486
examples/python/CuTeDSL/hopper/dense_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
31
examples/python/CuTeDSL/notebooks/README.md
Normal file
31
examples/python/CuTeDSL/notebooks/README.md
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
648
examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb
Normal file
648
examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb
Normal file
@ -0,0 +1,648 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e95f0df-4d1a-4e2e-92ff-90539bb4c517",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Example 06: CUDA Graphs\n",
|
||||
"\n",
|
||||
"In this example we demonstrate how to use CUDA graphs through PyTorch with CuTe DSL.\n",
|
||||
"The process of interacting with PyTorch's CUDA graph implementation requires exposing PyTorch's CUDA streams to CUTLASS.\n",
|
||||
"\n",
|
||||
"To use CUDA graphs with Blackwell requires a version of PyTorch that supports Blackwell.\n",
|
||||
"This can be obtained through:\n",
|
||||
"- The [PyTorch NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)\n",
|
||||
"- [PyTorch 2.7 with CUDA 12.8 or later](https://pytorch.org/) (e.g., `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128`)\n",
|
||||
"- Building PyTorch directly with your version of CUDA."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "46b8fb6f-9ac5-4a3d-b765-b6476f182bf7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# import torch for CUDA graphs\n",
|
||||
"import torch\n",
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"# import CUstream type from the cuda driver bindings\n",
|
||||
"from cuda.bindings.driver import CUstream\n",
|
||||
"# import the current_stream function from torch\n",
|
||||
"from torch.cuda import current_stream"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bcf5e06e-1f5b-4d72-ad73-9b36efb78ca0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Kernel Creation\n",
|
||||
"\n",
|
||||
"We create a kernel which prints \"Hello world\" as well as a host function to launch the kernel.\n",
|
||||
"We then compile the kernel for use in our graph, by passing in a default stream.\n",
|
||||
"\n",
|
||||
"Kernel compilation before graph capture is required since CUDA graphs cannot JIT compile kernels during graph execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "0c2a6ca8-98d7-4837-b91f-af769ca8fcd8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def hello_world_kernel():\n",
|
||||
" \"\"\"\n",
|
||||
" A kernel that prints hello world\n",
|
||||
" \"\"\"\n",
|
||||
" cute.printf(\"Hello world\")\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def hello_world(stream : CUstream):\n",
|
||||
" \"\"\"\n",
|
||||
" Host function that launches our (1,1,1), (1,1,1) grid in stream\n",
|
||||
" \"\"\"\n",
|
||||
" hello_world_kernel().launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream)\n",
|
||||
"\n",
|
||||
"# Grab a stream from PyTorch, this will also initialize our context\n",
|
||||
"# so we can omit cutlass.cuda.initialize_cuda_context()\n",
|
||||
"stream = current_stream()\n",
|
||||
"hello_world_compiled = cute.compile(hello_world, CUstream(stream.cuda_stream))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ecc850af-09f8-4a29-9c93-ff31fbb9326f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Creating and replaying a CUDA Graph\n",
|
||||
"\n",
|
||||
"We create a stream through torch as well as a graph.\n",
|
||||
"When we create the graph we can pass the stream we want to capture to torch. We similarly run the compiled kernel with the stream passed as a CUstream.\n",
|
||||
"\n",
|
||||
"Finally we can replay our graph and synchronize."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f673e5ae-42bb-44d0-b652-3280606181c4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello world\n",
|
||||
"Hello world\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a CUDA Graph\n",
|
||||
"g = torch.cuda.CUDAGraph()\n",
|
||||
"# Capture our graph\n",
|
||||
"with torch.cuda.graph(g):\n",
|
||||
" # Turn our torch Stream into a cuStream stream.\n",
|
||||
" # This is done by getting the underlying CUstream with .cuda_stream\n",
|
||||
" graph_stream = CUstream(current_stream().cuda_stream)\n",
|
||||
" # Run 2 iterations of our compiled kernel\n",
|
||||
" for _ in range(2):\n",
|
||||
" # Run our kernel in the stream\n",
|
||||
" hello_world_compiled(graph_stream)\n",
|
||||
"\n",
|
||||
"# Replay our graph\n",
|
||||
"g.replay()\n",
|
||||
"# Synchronize all streams (equivalent to cudaDeviceSynchronize() in C++)\n",
|
||||
"torch.cuda.synchronize()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "db76d9c3-7617-4bf2-b326-11982e6803bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Our run results in the following execution when viewed in NSight Systems:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"We can observe the launch of the two kernels followed by a `cudaDeviceSynchronize()`.\n",
|
||||
"\n",
|
||||
"Now we can confirm that this minimizes some launch overhead:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "3ebe15bf-dc97-42e9-913c-224ecfb472e8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get our CUDA stream from PyTorch\n",
|
||||
"stream = CUstream(current_stream().cuda_stream)\n",
|
||||
"\n",
|
||||
"# Create a larger CUDA Graph of 100 iterations\n",
|
||||
"g = torch.cuda.CUDAGraph()\n",
|
||||
"# Capture our graph\n",
|
||||
"with torch.cuda.graph(g):\n",
|
||||
" # Turn our torch Stream into a cuStream stream.\n",
|
||||
" # This is done by getting the underlying CUstream with .cuda_stream\n",
|
||||
" graph_stream = CUstream(current_stream().cuda_stream)\n",
|
||||
" # Run 2 iterations of our compiled kernel\n",
|
||||
" for _ in range(100):\n",
|
||||
" # Run our kernel in the stream\n",
|
||||
" hello_world_compiled(graph_stream)\n",
|
||||
"\n",
|
||||
"# Create CUDA events for measuring performance\n",
|
||||
"start = torch.cuda.Event(enable_timing=True)\n",
|
||||
"end = torch.cuda.Event(enable_timing=True)\n",
|
||||
"\n",
|
||||
"# Run our kernel to warm up the GPU\n",
|
||||
"for _ in range(100):\n",
|
||||
" hello_world_compiled(stream)\n",
|
||||
"\n",
|
||||
"# Record our start time\n",
|
||||
"start.record()\n",
|
||||
"# Run 100 kernels\n",
|
||||
"for _ in range(100):\n",
|
||||
" hello_world_compiled(stream)\n",
|
||||
"# Record our end time\n",
|
||||
"end.record()\n",
|
||||
"# Synchronize (cudaDeviceSynchronize())\n",
|
||||
"torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
"# Calculate the time spent when launching kernels in a stream\n",
|
||||
"# Results are in ms\n",
|
||||
"stream_time = start.elapsed_time(end) \n",
|
||||
"\n",
|
||||
"# Warmup our GPU again\n",
|
||||
"g.replay()\n",
|
||||
"# Record our start time\n",
|
||||
"start.record()\n",
|
||||
"# Run our graph\n",
|
||||
"g.replay()\n",
|
||||
"# Record our end time\n",
|
||||
"end.record()\n",
|
||||
"# Synchronize (cudaDeviceSynchronize())\n",
|
||||
"torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
"# Calculate the time spent when launching kernels in a graph\n",
|
||||
"# units are ms\n",
|
||||
"graph_time = start.elapsed_time(end)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "12b8151a-46b3-4c99-9945-301f6b628131",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"8.94% speedup when using CUDA graphs for this kernel!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Print out speedup when using CUDA graphs\n",
|
||||
"percent_speedup = (stream_time - graph_time) / graph_time\n",
|
||||
"print(f\"{percent_speedup * 100.0:.2f}% speedup when using CUDA graphs for this kernel!\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
1001
examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb
Normal file
1001
examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
310
examples/python/CuTeDSL/notebooks/data_types.ipynb
Normal file
310
examples/python/CuTeDSL/notebooks/data_types.ipynb
Normal file
@ -0,0 +1,310 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Understanding data structure in CuTe DSL\n",
|
||||
"\n",
|
||||
"In most cases, data structures in CuTe DSL work the same as Python data structures with the notable difference that Python data structures in most cases are considered as static data which are interpreted by the DSL compiler embedded inside Python interpreter.\n",
|
||||
"\n",
|
||||
"To differentiate between compile-time and runtime values, CuTe DSL introduces primitive types that \n",
|
||||
"represent dynamic values in JIT-compiled code.\n",
|
||||
"\n",
|
||||
"CuTe DSL provides a comprehensive set of primitive numeric types for representing dynamic values at \n",
|
||||
"runtime. These types are formally defined within the CuTe DSL typing system:\n",
|
||||
"\n",
|
||||
"### Integer Types\n",
|
||||
"- `Int8` - 8-bit signed integer\n",
|
||||
"- `Int16` - 16-bit signed integer \n",
|
||||
"- `Int32` - 32-bit signed integer\n",
|
||||
"- `Int64` - 64-bit signed integer\n",
|
||||
"- `Int128` - 128-bit signed integer\n",
|
||||
"- `Uint8` - 8-bit unsigned integer\n",
|
||||
"- `Uint16` - 16-bit unsigned integer\n",
|
||||
"- `Uint32` - 32-bit unsigned integer\n",
|
||||
"- `Uint64` - 64-bit unsigned integer\n",
|
||||
"- `Uint128` - 128-bit unsigned integer\n",
|
||||
"\n",
|
||||
"### Floating Point Types\n",
|
||||
"- `Float16` - 16-bit floating point\n",
|
||||
"- `Float32` - 32-bit floating point \n",
|
||||
"- `Float64` - 64-bit floating point\n",
|
||||
"- `BFloat16` - Brain Floating Point format (16-bit)\n",
|
||||
"- `TFloat32` - Tensor Float32 format (reduced precision format used in tensor operations)\n",
|
||||
"- `Float8E4M3` - 8-bit floating point with 4-bit exponent and 3-bit mantissa\n",
|
||||
"- `Float8E5M2` - 8-bit floating point with 5-bit exponent and 2-bit mantissa\n",
|
||||
"\n",
|
||||
"These specialized types are designed to represent dynamic values in CuTe DSL code that will be \n",
|
||||
"evaluated at runtime, in contrast to Python's built-in numeric types which are evaluated during \n",
|
||||
"compilation.\n",
|
||||
"\n",
|
||||
"### Example usage:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"x = cutlass.Int32(5) # Creates a 32-bit integer\n",
|
||||
"y = cutlass.Float32(3.14) # Creates a 32-bit float\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
|
||||
" ...\n",
|
||||
"```\n",
|
||||
"To differentiate between compile-time and runtime values, CuTe DSL introduces primitive types that \n",
|
||||
"represent dynamic values in JIT-compiled code.\n",
|
||||
"\n",
|
||||
"CuTe DSL provides a comprehensive set of primitive numeric types for representing dynamic values at \n",
|
||||
"runtime. These types are formally defined within the CuTe DSL typing system:\n",
|
||||
"\n",
|
||||
"### Integer Types\n",
|
||||
"- `Int8` - 8-bit signed integer\n",
|
||||
"- `Int16` - 16-bit signed integer \n",
|
||||
"- `Int32` - 32-bit signed integer\n",
|
||||
"- `Int64` - 64-bit signed integer\n",
|
||||
"- `Int128` - 128-bit signed integer\n",
|
||||
"- `Uint8` - 8-bit unsigned integer\n",
|
||||
"- `Uint16` - 16-bit unsigned integer\n",
|
||||
"- `Uint32` - 32-bit unsigned integer\n",
|
||||
"- `Uint64` - 64-bit unsigned integer\n",
|
||||
"- `Uint128` - 128-bit unsigned integer\n",
|
||||
"\n",
|
||||
"### Floating Point Types\n",
|
||||
"- `Float16` - 16-bit floating point\n",
|
||||
"- `Float32` - 32-bit floating point \n",
|
||||
"- `Float64` - 64-bit floating point\n",
|
||||
"- `BFloat16` - Brain Floating Point format (16-bit)\n",
|
||||
"- `TFloat32` - Tensor Float32 format (reduced precision format used in tensor operations)\n",
|
||||
"- `Float8E4M3` - 8-bit floating point with 4-bit exponent and 3-bit mantissa\n",
|
||||
"- `Float8E5M2` - 8-bit floating point with 5-bit exponent and 2-bit mantissa\n",
|
||||
"\n",
|
||||
"These specialized types are designed to represent dynamic values in CuTe DSL code that will be \n",
|
||||
"evaluated at runtime, in contrast to Python's built-in numeric types which are evaluated during \n",
|
||||
"compilation.\n",
|
||||
"\n",
|
||||
"### Example usage:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"x = cutlass.Int32(5) # Creates a 32-bit integer\n",
|
||||
"y = cutlass.Float32(3.14) # Creates a 32-bit float\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
|
||||
" ...\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a(static) = ?\n",
|
||||
"b(static) = ?\n",
|
||||
"a(dynamic) = 3.140000\n",
|
||||
"b(dynamic) = 5\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def bar():\n",
|
||||
" a = cutlass.Float32(3.14)\n",
|
||||
" print(\"a(static) =\", a) # prints `a(static) = ?`\n",
|
||||
" cute.printf(\"a(dynamic) = {}\", a) # prints `a(dynamic) = 3.140000`\n",
|
||||
"\n",
|
||||
" b = cutlass.Int32(5)\n",
|
||||
" print(\"b(static) =\", b) # prints `b(static) = 5`\n",
|
||||
" cute.printf(\"b(dynamic) = {}\", b) # prints `b(dynamic) = 5`\n",
|
||||
"\n",
|
||||
"bar()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Type Conversion API\n",
|
||||
"\n",
|
||||
"CUTLASS numeric types provide type conversion through the `to()` method available on all Numeric types. This allows you to convert between different numeric data types at runtime.\n",
|
||||
"\n",
|
||||
"Syntax:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"new_value = value.to(target_type)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The `to()` method supports conversion between:\n",
|
||||
"- Integer types (Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64)\n",
|
||||
"- Floating point types (Float16, Float32, Float64, BFloat16)\n",
|
||||
"- Mixed integer/floating point conversions\n",
|
||||
"\n",
|
||||
"Note that when converting from floating point to integer types, the decimal portion is truncated. When converting between types with different ranges, values may be clamped or lose precision if they exceed the target type's representable range."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Int32(42) => Float32(42.000000)\n",
|
||||
"Float32(3.140000) => Int32(3)\n",
|
||||
"Int32(127) => Int8(127)\n",
|
||||
"Int32(300) => Int8(44) (truncated due to range limitation)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def type_conversion():\n",
|
||||
" # Convert from Int32 to Float32\n",
|
||||
" x = cutlass.Int32(42)\n",
|
||||
" y = x.to(cutlass.Float32)\n",
|
||||
" cute.printf(\"Int32({}) => Float32({})\", x, y)\n",
|
||||
"\n",
|
||||
" # Convert from Float32 to Int32\n",
|
||||
" a = cutlass.Float32(3.14)\n",
|
||||
" b = a.to(cutlass.Int32)\n",
|
||||
" cute.printf(\"Float32({}) => Int32({})\", a, b)\n",
|
||||
"\n",
|
||||
" # Convert from Int32 to Int8\n",
|
||||
" c = cutlass.Int32(127)\n",
|
||||
" d = c.to(cutlass.Int8)\n",
|
||||
" cute.printf(\"Int32({}) => Int8({})\", c, d)\n",
|
||||
"\n",
|
||||
" # Convert from Int32 to Int8 with value exceeding Int8 range\n",
|
||||
" e = cutlass.Int32(300)\n",
|
||||
" f = e.to(cutlass.Int8)\n",
|
||||
" cute.printf(\"Int32({}) => Int8({}) (truncated due to range limitation)\", e, f)\n",
|
||||
"\n",
|
||||
"type_conversion()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Operator Overloading\n",
|
||||
"\n",
|
||||
"CUTLASS numeric types support Python's built-in operators, allowing you to write natural mathematical expressions. The operators work with both CUTLASS numeric types and Python native numeric types.\n",
|
||||
"\n",
|
||||
"Supported operators include:\n",
|
||||
"- Arithmetic: `+`, `-`, `*`, `/`, `//`, `%`, `**`\n",
|
||||
"- Comparison: `<`, `<=`, `==`, `!=`, `>=`, `>`\n",
|
||||
"- Bitwise: `&`, `|`, `^`, `<<`, `>>`\n",
|
||||
"- Unary: `-` (negation), `~` (bitwise NOT)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a: Int32(10), b: Int32(3)\n",
|
||||
"x: Float32(5.500000)\n",
|
||||
"\n",
|
||||
"a + b = 13\n",
|
||||
"x * 2 = 11.000000\n",
|
||||
"a + x = 15.500000 (Int32 + Float32 promotes to Float32)\n",
|
||||
"a / b = 3.333333\n",
|
||||
"x / 2.0 = 2.750000\n",
|
||||
"a > b = 1\n",
|
||||
"a & b = 2\n",
|
||||
"-a = -10\n",
|
||||
"~a = -11\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def operator_demo():\n",
|
||||
" # Arithmetic operators\n",
|
||||
" a = cutlass.Int32(10)\n",
|
||||
" b = cutlass.Int32(3)\n",
|
||||
" cute.printf(\"a: Int32({}), b: Int32({})\", a, b)\n",
|
||||
"\n",
|
||||
" x = cutlass.Float32(5.5)\n",
|
||||
" cute.printf(\"x: Float32({})\", x)\n",
|
||||
"\n",
|
||||
" cute.printf(\"\")\n",
|
||||
"\n",
|
||||
" sum_result = a + b\n",
|
||||
" cute.printf(\"a + b = {}\", sum_result)\n",
|
||||
"\n",
|
||||
" y = x * 2 # Multiplying with Python native type\n",
|
||||
" cute.printf(\"x * 2 = {}\", y)\n",
|
||||
"\n",
|
||||
" # Mixed type arithmetic (Int32 + Float32) that integer is converted into float32\n",
|
||||
" mixed_result = a + x\n",
|
||||
" cute.printf(\"a + x = {} (Int32 + Float32 promotes to Float32)\", mixed_result)\n",
|
||||
"\n",
|
||||
" # Division with Int32 (note: integer division)\n",
|
||||
" div_result = a / b\n",
|
||||
" cute.printf(\"a / b = {}\", div_result)\n",
|
||||
"\n",
|
||||
" # Float division\n",
|
||||
" float_div = x / cutlass.Float32(2.0)\n",
|
||||
" cute.printf(\"x / 2.0 = {}\", float_div)\n",
|
||||
"\n",
|
||||
" # Comparison operators\n",
|
||||
" is_greater = a > b\n",
|
||||
" cute.printf(\"a > b = {}\", is_greater)\n",
|
||||
"\n",
|
||||
" # Bitwise operators\n",
|
||||
" bit_and = a & b\n",
|
||||
" cute.printf(\"a & b = {}\", bit_and)\n",
|
||||
"\n",
|
||||
" neg_a = -a\n",
|
||||
" cute.printf(\"-a = {}\", neg_a)\n",
|
||||
"\n",
|
||||
" not_a = ~a\n",
|
||||
" cute.printf(\"~a = {}\", not_a)\n",
|
||||
"\n",
|
||||
"operator_demo()\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
838
examples/python/CuTeDSL/notebooks/elementwise_add.ipynb
Normal file
838
examples/python/CuTeDSL/notebooks/elementwise_add.ipynb
Normal file
@ -0,0 +1,838 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from functools import partial\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"from cutlass.cute.runtime import from_dlpack"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tutorial: Elementwise Add Kernel in CuTe DSL\n",
|
||||
"\n",
|
||||
"This tutorial demonstrates how to implement a simple elementwise\n",
|
||||
"addition kernel using the CuTe DSL (Domain Specific Language).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Elementwise Addition\n",
|
||||
"---------------------\n",
|
||||
"\n",
|
||||
"Elementwise addition is a fundamental operation in linear algebra.\n",
|
||||
"Given two tensors of the same shape, the operation performs element-wise\n",
|
||||
"addition to produce a result tensor of the same shape.\n",
|
||||
"\n",
|
||||
"For two 2D tensors :math:`A` and :math:`B` of shape :math:`(M, N)`,\n",
|
||||
"the elementwise addition operation :math:`C = A + B` is defined as:\n",
|
||||
"\n",
|
||||
"$\n",
|
||||
" C_{i,j} = A_{i,j} + B_{i,j}\n",
|
||||
"$\n",
|
||||
"\n",
|
||||
"where:\n",
|
||||
"\n",
|
||||
"- $i \\in [0, M-1]$ represents the row index\n",
|
||||
"- $j \\in [0, N-1]$ represents the column index\n",
|
||||
"- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ \n",
|
||||
" in tensors $A$, $B$, and $C$ respectively\n",
|
||||
"\n",
|
||||
"This operation is performed independently for each element position,\n",
|
||||
"making it highly parallelizable and well-suited for GPU implementation.\n",
|
||||
"\n",
|
||||
"Naive Elementwise Add Kernel\n",
|
||||
"-----------------------------\n",
|
||||
"\n",
|
||||
"Let's start with a naive implementation that loads each element from\n",
|
||||
"$A$ and $B$, adds them, and stores the result back to $C$."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def naive_elementwise_add_kernel(\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
" bdim, _, _ = cute.arch.block_dim()\n",
|
||||
"\n",
|
||||
" thread_idx = bidx * bdim + tidx\n",
|
||||
"\n",
|
||||
" # Map thread index to logical index of input tensor\n",
|
||||
" m, n = gA.shape\n",
|
||||
" ni = thread_idx % n\n",
|
||||
" mi = thread_idx // n\n",
|
||||
"\n",
|
||||
" # Map logical index to physical address via tensor layout\n",
|
||||
" a_val = gA[mi, ni]\n",
|
||||
" b_val = gB[mi, ni]\n",
|
||||
"\n",
|
||||
" # Perform element-wise addition\n",
|
||||
" gC[mi, ni] = a_val + b_val"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Structure of the Kernel\n",
|
||||
"\n",
|
||||
"The naive kernel simply maps each thread to one element with a 1-to-1 mapping.\n",
|
||||
"In this kernel, we don't use CuTe layout algebra but only use basic\n",
|
||||
"addressing to index the tensor.\n",
|
||||
"\n",
|
||||
"We can launch the kernel with the following JIT function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def naive_elementwise_add(\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor\n",
|
||||
"):\n",
|
||||
" num_threads_per_block = 256\n",
|
||||
"\n",
|
||||
" m, n = mA.shape\n",
|
||||
" kernel = naive_elementwise_add_kernel(mA, mB, mC)\n",
|
||||
" kernel.launch(grid=((m * n) // num_threads_per_block, 1, 1),\n",
|
||||
" block=(num_threads_per_block, 1, 1))\n",
|
||||
"\n",
|
||||
"M, N = 2048, 2048\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"# Compile kernel\n",
|
||||
"naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)\n",
|
||||
"naive_elementwise_add_(a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, a + b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Benchmark performance\n",
|
||||
"\n",
|
||||
"Here's a utility function to benchmark our kernel implementations:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(callable, *, num_warmups, num_iterations):\n",
|
||||
" start_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" end_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
"\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
" for _ in range(num_warmups):\n",
|
||||
" callable()\n",
|
||||
"\n",
|
||||
" start_event.record(stream=torch.cuda.current_stream())\n",
|
||||
" for _ in range(num_iterations):\n",
|
||||
" callable()\n",
|
||||
" end_event.record(stream=torch.cuda.current_stream())\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
" elapsed_time = start_event.elapsed_time(end_event)\n",
|
||||
" avg_time = elapsed_time / num_iterations\n",
|
||||
"\n",
|
||||
" print(f\"Average execution time: {avg_time:.4f} ms\")\n",
|
||||
" print(f\"Throughput: {(3 * a.numel() * 2) / (avg_time / 1000) / 1e9:.2f} GB/s\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Average execution time: 0.0385 ms\n",
|
||||
"Throughput: 653.44 GB/s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"benchmark(partial(naive_elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Performance Analysis\n",
|
||||
"\n",
|
||||
"While our naive implementation maps thread indices to contiguous tensor\n",
|
||||
"dimensions for coalesced memory access, it doesn't have enough\n",
|
||||
"in-flight load & store operations to hide memory latency.\n",
|
||||
"\n",
|
||||
"According to Little's Law:\n",
|
||||
"\n",
|
||||
"$ L = \\lambda \\times W $\n",
|
||||
"\n",
|
||||
"Where:\n",
|
||||
"- $L$ is the average number of items in a system\n",
|
||||
"- $\\lambda$ is the average arrival rate of items (bandwidth)\n",
|
||||
"- $W$ is the average time an item spends in the system (latency)\n",
|
||||
"\n",
|
||||
"For our elementwise addition kernel:\n",
|
||||
"\n",
|
||||
"1. $L$: The number of load & store operations in-flight\n",
|
||||
"2. $\\lambda$ (Bandwidth): Data transfer rate between memory and compute units\n",
|
||||
"3. $W$ (Latency): Round-trip delay of memory requests\n",
|
||||
"\n",
|
||||
"For memory-bound operations like elementwise addition, performance is\n",
|
||||
"limited by the number of in-flight load & store operations.\n",
|
||||
"\n",
|
||||
"## Vectorized Load and Store\n",
|
||||
"\n",
|
||||
"To improve performance according to Little's Law, we need to increase the number\n",
|
||||
"of in-flight requests. We can do this by increasing the number of bytes handled\n",
|
||||
"in each load & store operation per thread through vectorized memory access.\n",
|
||||
"\n",
|
||||
"Since Ampere GPUs support up to 128-bit per load/store and each element is 32-bit,\n",
|
||||
"we can load 4 elements per vectorized operation on contiguous rows.\n",
|
||||
"CuTe tiling operations make this vectorization straightforward.\n",
|
||||
"\n",
|
||||
"Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n",
|
||||
"``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\n",
|
||||
"as the block of data each thread accesses (4 contiguous elements in the same row, or ``(1,4)``).\n",
|
||||
"Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"mA : cute.Tensor # (2048,2048):(2048,1)\n",
|
||||
"gA = cute.zipped_divide(a, tiler=(1, 4)) # tiled/vectorized => ((1,4),(2048,512)):((0,1),(2048,4))\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"$\n",
|
||||
" \\begin{array}{ccccc}\n",
|
||||
" & ((1,4) & , & (2048,512)) & : ((0,1),(2048,4)) \\\\\n",
|
||||
" & \\underbrace{\\phantom{(1,4)}}_{tiler} & & \\underbrace{\\phantom{(2048,512)}}_{threads} & \\\\\n",
|
||||
" & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n",
|
||||
" \\end{array}\n",
|
||||
"$"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def vectorized_elementwise_add_kernel(\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
" bdim, _, _ = cute.arch.block_dim()\n",
|
||||
"\n",
|
||||
" thread_idx = bidx * bdim + tidx\n",
|
||||
"\n",
|
||||
" # Map thread index to logical index of input tensor\n",
|
||||
" m, n = gA.shape[1] # thread-domain\n",
|
||||
" ni = thread_idx % n\n",
|
||||
" mi = thread_idx // n\n",
|
||||
"\n",
|
||||
" # Map logical index to physical address via tensor layout\n",
|
||||
" a_val = gA[(None, (mi, ni))].load()\n",
|
||||
" b_val = gB[(None, (mi, ni))].load()\n",
|
||||
" print(f\"[DSL INFO] sliced gA = {gA[(None, (mi, ni))]}\")\n",
|
||||
" print(f\"[DSL INFO] sliced gB = {gB[(None, (mi, ni))]}\")\n",
|
||||
"\n",
|
||||
" # Perform element-wise addition\n",
|
||||
" gC[(None, (mi, ni))] = a_val + b_val"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n",
|
||||
"with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n",
|
||||
"we can extract a `(1,4)` sub-tensor from `gA`, `gB` and `gC` like \n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"gA[(None, (mi, ni))]\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Then tensor data can be loaded into vector via the `.load()` method.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
" slice\n",
|
||||
" ((1,4),(2048,512)):((0,1),(2048,4)) ==> ((1,4)):((0,1))\n",
|
||||
" ^ ^ ^\n",
|
||||
" | | |\n",
|
||||
" (None, (mi, ni))\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[DSL INFO] Tiled Tensors:\n",
|
||||
"[DSL INFO] gA = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
|
||||
"[DSL INFO] gB = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
|
||||
"[DSL INFO] gC = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
|
||||
"[DSL INFO] sliced gA = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n",
|
||||
"[DSL INFO] sliced gB = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def vectorized_elementwise_add(\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor\n",
|
||||
"):\n",
|
||||
" threads_per_block = 256\n",
|
||||
"\n",
|
||||
" gA = cute.zipped_divide(mA, (1, 4))\n",
|
||||
" gB = cute.zipped_divide(mB, (1, 4))\n",
|
||||
" gC = cute.zipped_divide(mC, (1, 4))\n",
|
||||
"\n",
|
||||
" print(f\"[DSL INFO] Tiled Tensors:\")\n",
|
||||
" print(f\"[DSL INFO] gA = {gA}\")\n",
|
||||
" print(f\"[DSL INFO] gB = {gB}\")\n",
|
||||
" print(f\"[DSL INFO] gC = {gC}\")\n",
|
||||
"\n",
|
||||
" vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n",
|
||||
" grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n",
|
||||
" block=(threads_per_block, 1, 1),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
|
||||
"compiled_func(a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, a + b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Average execution time: 0.0202 ms\n",
|
||||
"Throughput: 1244.98 GB/s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"benchmark(partial(compiled_func, a_, b_, c_), num_warmups=5, num_iterations=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## TV Layout\n",
|
||||
"\n",
|
||||
"Both the naive and vectorized kernels follow a common pattern to map thread indices\n",
|
||||
"to physical addresses:\n",
|
||||
"\n",
|
||||
"Step 1: Map thread index to logical M/N coordinates\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
" mi = thread_idx // n\n",
|
||||
" ni = thread_idx % n\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Step 2: Map logical M/N coordinates to physical addresses using the tensor layout\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
" a[(None, (mi, ni))].load()\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"CuTe uses TV layout to represent this mapping from thread index and value index\n",
|
||||
"(i.e., the 4 elements loaded per thread) to the logical coordinate space of a tensor.\n",
|
||||
"By configuring different TV layouts, we can experiment with different memory access\n",
|
||||
"patterns with minimal code changes.\n",
|
||||
"\n",
|
||||
"The following example demonstrates two levels of tiling: at the thread-block level\n",
|
||||
"and at the thread level.\n",
|
||||
"\n",
|
||||
"For thread-block level tiling, each input & output tensor is first divided\n",
|
||||
"into a group of ``(TileM, TileN)`` sub-tensors at the host side.\n",
|
||||
"\n",
|
||||
"Inside the GPU kernel, we provide the thread-block index to the 2nd mode of the tiled tensor\n",
|
||||
"(``gA[((None, None), bidx)]``), which returns a thread-block local view of\n",
|
||||
"a single ``(TileM, TileN)`` sub-tensor.\n",
|
||||
"\n",
|
||||
"For thread level tiling, we compose the sub-tensor (which maps from logical coordinates\n",
|
||||
"to physical addresses) with the TV layout (which maps from thread & value indices to\n",
|
||||
"logical coordinates). This gives us a tiled sub-tensor that maps from thread & value\n",
|
||||
"indices directly to physical addresses.\n",
|
||||
"\n",
|
||||
"We then provide the thread index to the tiled sub-tensor (``tidfrgA[(tidx, None)]``)\n",
|
||||
"to get a thread-local view of the data each thread accesses. Note that the thread index\n",
|
||||
"is now in the 1st mode, as the tiled sub-tensor puts the thread mode before the value mode."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def elementwise_add_kernel(\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
" tv_layout: cute.Layout\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # slice for thread-block level view\n",
|
||||
" #--------------------------------\n",
|
||||
" blk_coord = ((None, None), bidx)\n",
|
||||
"\n",
|
||||
" # logical coord -> address\n",
|
||||
" blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # compose for thread-index & value-index to physical mapping\n",
|
||||
" #--------------------------------\n",
|
||||
" # blockA: (TileM, TileN) -> physical address\n",
|
||||
" # tv_layout: (tid, vid) -> (TileM, TileN)\n",
|
||||
" # tidfrgA = blkA o tv_layout\n",
|
||||
" # tidfrgA: (tid, vid) -> physical address\n",
|
||||
" tidfrgA = cute.composition(blkA, tv_layout)\n",
|
||||
" tidfrgB = cute.composition(blkB, tv_layout)\n",
|
||||
" tidfrgC = cute.composition(blkC, tv_layout)\n",
|
||||
"\n",
|
||||
" print(f\"Composed with TV layout:\")\n",
|
||||
" print(f\" tidfrgA: {tidfrgA.type}\")\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # slice for thread-level view\n",
|
||||
" #--------------------------------\n",
|
||||
" # `None` represent slice of the entire per-thread data\n",
|
||||
" thr_coord = (tidx, None)\n",
|
||||
"\n",
|
||||
" # slice for threads: vid -> address\n",
|
||||
" thrA = tidfrgA[thr_coord] # (V) -> physical address\n",
|
||||
" thrB = tidfrgB[thr_coord] # (V) -> physical address\n",
|
||||
" thrC = tidfrgC[thr_coord] # (V) -> physical address\n",
|
||||
"\n",
|
||||
" thrC[None] = thrA.load() + thrB.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we take a closer look at the layout of zipped divided input tensor `gA`:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"Tiled to Thread Block:\n",
|
||||
"\n",
|
||||
" ((16,256),(128,8)) : ((2048,1),(32768,256))\n",
|
||||
" ~~~~~~~~ ~~~~~~ ~~~~~~~~\n",
|
||||
" | | |\n",
|
||||
" | | |\n",
|
||||
" | `------------------------> Number of Thread Blocks\n",
|
||||
" | |\n",
|
||||
" | |\n",
|
||||
" `--------------------'\n",
|
||||
" |\n",
|
||||
" V\n",
|
||||
" Thread Block\n",
|
||||
" Tile\n",
|
||||
"\n",
|
||||
"Sliced to Thread-Block local sub-tensor (a (16, 256) tile): gA[((None, None), bidx)]\n",
|
||||
"\n",
|
||||
" (16,256) : (2048,1)\n",
|
||||
" ~~~~~~ ~~~~~~\n",
|
||||
" | | Tiled/Composed with TV Layout\n",
|
||||
" | | \n",
|
||||
" | | o ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
" V V \n",
|
||||
"~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~ \n",
|
||||
"((32,4), (8,4)) : ((4,8192),(1,2048))\n",
|
||||
" | |\n",
|
||||
" | `--------> per thread fragment\n",
|
||||
" |\n",
|
||||
"Thread Block\n",
|
||||
" Shape\n",
|
||||
"\n",
|
||||
"Sliced to Thread local sub-tensor (a (4,8) tile): tidfrgA[(tidx, None)]\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The host code below shows the construction of the TV layout. By composing\n",
|
||||
"a thread layout of ``(4,32):(32,1)`` (32 threads read contiguous elements on the row dimension,\n",
|
||||
"then 4 warps read different rows) with a value layout of ``(4,8):(8,1)`` (each thread reads\n",
|
||||
"8 contiguous elements on the row dimension across 4 contiguous rows),\n",
|
||||
"we obtain the TV layout shown in the figure above."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tiler: (16, 256)\n",
|
||||
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
"Tiled Input Tensors:\n",
|
||||
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
"Composed with TV layout:\n",
|
||||
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def elementwise_add(\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" # mA layout: (M, N):(N, 1)\n",
|
||||
" # TV layout map thread & value index to (16, 256) logical tile\n",
|
||||
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
|
||||
" # mode-1 for coalesced load-store\n",
|
||||
" # - each thread load 8 contiguous element each row and load 4 rows\n",
|
||||
" thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
|
||||
" val_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
|
||||
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
|
||||
" print(f\"Tiler: {tiler_mn}\")\n",
|
||||
" print(f\"TV Layout: {tv_layout}\")\n",
|
||||
"\n",
|
||||
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
"\n",
|
||||
" print(f\"Tiled Input Tensors:\")\n",
|
||||
" print(f\" gA: {gA.type}\")\n",
|
||||
" print(f\" gB: {gB.type}\")\n",
|
||||
" print(f\" gC: {gC.type}\")\n",
|
||||
"\n",
|
||||
" # Launch the kernel asynchronously\n",
|
||||
" # Async token(s) can also be specified as dependencies\n",
|
||||
" elementwise_add_kernel(\n",
|
||||
" gA, gB, gC, tv_layout\n",
|
||||
" ).launch(\n",
|
||||
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
|
||||
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n",
|
||||
"elementwise_add_(a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, a + b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Average execution time: 0.0222 ms\n",
|
||||
"Throughput: 1133.58 GB/s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"benchmark(partial(elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=200)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Using Lambda Function\n",
|
||||
"\n",
|
||||
"CuTe DSL is built on top of Python. It can leverage Python to implement meta-programming to generate flexible kernels.\n",
|
||||
"E.g. we can write kernel template that take custom binary operations to generate kernels for arbitrary binary operations.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"@cute.jit\n",
|
||||
"def elementwise_apply(\n",
|
||||
" op: cutlass.Constexpr,\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor\n",
|
||||
"):\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tiler: (16, 256)\n",
|
||||
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
"Tiled Input Tensors:\n",
|
||||
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
"Composed with TV layout:\n",
|
||||
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def elementwise_apply_kernel(\n",
|
||||
" op: cutlass.Constexpr, # lambda function must be const expr to generate code at compile time\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
" tv_layout: cute.Layout\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
"\n",
|
||||
" blk_coord = ((None, None), bidx)\n",
|
||||
"\n",
|
||||
" # logical coord -> address\n",
|
||||
" blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
"\n",
|
||||
" tidfrgA = cute.composition(blkA, tv_layout)\n",
|
||||
" tidfrgB = cute.composition(blkB, tv_layout)\n",
|
||||
" tidfrgC = cute.composition(blkC, tv_layout)\n",
|
||||
"\n",
|
||||
" print(f\"Composed with TV layout:\")\n",
|
||||
" print(f\" tidfrgA: {tidfrgA.type}\")\n",
|
||||
"\n",
|
||||
" thr_coord = (tidx, None)\n",
|
||||
"\n",
|
||||
" # slice for threads: vid -> address\n",
|
||||
" thrA = tidfrgA[thr_coord] # (V) -> physical address\n",
|
||||
" thrB = tidfrgB[thr_coord] # (V) -> physical address\n",
|
||||
" thrC = tidfrgC[thr_coord] # (V) -> physical address\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # apply custom operation\n",
|
||||
" #--------------------------------\n",
|
||||
" thrC[None] = op(thrA.load(), thrB.load())\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def elementwise_op(\n",
|
||||
" op: cutlass.Constexpr,\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" # mA layout: (M, N):(N, 1)\n",
|
||||
" # TV layout map thread & value index to (16, 256) logical tile\n",
|
||||
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
|
||||
" # mode-1 for coalesced load-store\n",
|
||||
" # - each thread load 8 contiguous element each row and load 4 rows\n",
|
||||
" thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
|
||||
" val_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
|
||||
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
|
||||
" print(f\"Tiler: {tiler_mn}\")\n",
|
||||
" print(f\"TV Layout: {tv_layout}\")\n",
|
||||
"\n",
|
||||
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
"\n",
|
||||
" print(f\"Tiled Input Tensors:\")\n",
|
||||
" print(f\" gA: {gA.type}\")\n",
|
||||
" print(f\" gB: {gB.type}\")\n",
|
||||
" print(f\" gC: {gC.type}\")\n",
|
||||
"\n",
|
||||
" # Launch the kernel asynchronously\n",
|
||||
" # Async token(s) can also be specified as dependencies\n",
|
||||
" elementwise_apply_kernel(\n",
|
||||
" op, gA, gB, gC, tv_layout\n",
|
||||
" ).launch(\n",
|
||||
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
|
||||
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"from operator import mul\n",
|
||||
"\n",
|
||||
"elementwise_op(mul, a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, mul(a, b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Custom operators can be more complex. For example, here's a function that performs\n",
|
||||
"multiplication followed by ReLU:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tiler: (16, 256)\n",
|
||||
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
"Tiled Input Tensors:\n",
|
||||
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
"Composed with TV layout:\n",
|
||||
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def mul_relu(a, b):\n",
|
||||
" tmp = a * b\n",
|
||||
" return cute.where(tmp > 0, tmp, cute.full_like(tmp, 0))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# As we uses cute.where in customized operation, we need to create another relu function\n",
|
||||
"def mul_relu_ref(a, b):\n",
|
||||
" tmp = a * b\n",
|
||||
" return torch.relu(tmp)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"elementwise_op(mul_relu, a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, mul_relu_ref(a, b))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"state": {},
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
174
examples/python/CuTeDSL/notebooks/hello_world.ipynb
Normal file
174
examples/python/CuTeDSL/notebooks/hello_world.ipynb
Normal file
@ -0,0 +1,174 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Your First Program with CuTe DSL\n",
|
||||
"\n",
|
||||
"## Introduction\n",
|
||||
"\n",
|
||||
"Welcome! In this tutorial, we'll write a simple \"Hello World\" program that runs on your GPU using CuTe DSL. This will help you understand the basics of GPU programming with our framework.\n",
|
||||
"\n",
|
||||
"### What You'll Learn\n",
|
||||
"\n",
|
||||
"- How to write code that runs on both CPU (host) and GPU (device),\n",
|
||||
"- How to launch a GPU kernel (a function that runs on the GPU),\n",
|
||||
"- Basic CUDA concepts like threads and thread blocks,\n",
|
||||
"\n",
|
||||
"### Step 1: Import Required Libraries\n",
|
||||
"\n",
|
||||
"First, let's import the libraries we need:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass \n",
|
||||
"import cutlass.cute as cute "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"### Step 2: Write Our GPU Kernel\n",
|
||||
"A GPU kernel is a function that runs on the GPU. Here's a simple kernel that prints \"Hello World\".\n",
|
||||
"Key concepts:\n",
|
||||
"- `@cute.kernel`: This decorator tells CUTLASS that this function should run on the GPU\n",
|
||||
"- `cute.arch.thread_idx()`: Gets the ID of the current GPU thread (like a worker's ID number)\n",
|
||||
"- We only want one thread to print the message (thread 0) to avoid multiple prints"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def kernel():\n",
|
||||
" # Get the x component of the thread index (y and z components are unused)\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" # Only the first thread (thread 0) prints the message\n",
|
||||
" if tidx == 0:\n",
|
||||
" cute.printf(\"Hello world\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Step 3: Write Our Host Function\n",
|
||||
"\n",
|
||||
"Now we need a function that sets up the GPU and launches our kernel.\n",
|
||||
"Key concepts:\n",
|
||||
"- `@cute.jit`: This decorator is for functions that run on the CPU but can launch GPU code\n",
|
||||
"- We need to initialize CUDA before using the GPU\n",
|
||||
"- `.launch()` tells CUDA how many blocks, threads, shared memory, etc. to use"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def hello_world():\n",
|
||||
"\n",
|
||||
" # Print hello world from host code\n",
|
||||
" cute.printf(\"hello world\")\n",
|
||||
"\n",
|
||||
" # Launch kernel\n",
|
||||
" kernel().launch(\n",
|
||||
" grid=(1, 1, 1), # Single thread block\n",
|
||||
" block=(32, 1, 1) # One warp (32 threads) per thread block\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Step 4: Run Our Program\n",
|
||||
"\n",
|
||||
"There are 2 ways we can run our program:\n",
|
||||
"\n",
|
||||
"1. compile and run immediately\n",
|
||||
"2. separate compilation which allows us to compile the code once and run multiple times\n",
|
||||
" \n",
|
||||
"Please note the `Compiling...` for Method 2 prints before the \"Hello world\" of the first kernel. This shows the asynchronous behavior between CPU and GPU prints. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running hello_world()...\n",
|
||||
"hello world\n",
|
||||
"Compiling...\n",
|
||||
"Hello world\n",
|
||||
"Running compiled version...\n",
|
||||
"hello world\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Initialize CUDA context for launching a kernel with error checking\n",
|
||||
"# We make context initialization explicit to allow users to control the context creation \n",
|
||||
"# and avoid potential issues with multiple contexts\n",
|
||||
"cutlass.cuda.initialize_cuda_context()\n",
|
||||
"\n",
|
||||
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
|
||||
"print(\"Running hello_world()...\")\n",
|
||||
"hello_world()\n",
|
||||
"\n",
|
||||
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
|
||||
"print(\"Compiling...\")\n",
|
||||
"hello_world_compiled = cute.compile(hello_world)\n",
|
||||
"\n",
|
||||
"# Run the pre-compiled version\n",
|
||||
"print(\"Running compiled version...\")\n",
|
||||
"hello_world_compiled()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"state": {},
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
BIN
examples/python/CuTeDSL/notebooks/images/cuda_graphs_image.png
Normal file
BIN
examples/python/CuTeDSL/notebooks/images/cuda_graphs_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.4 KiB |
425
examples/python/CuTeDSL/notebooks/print.ipynb
Normal file
425
examples/python/CuTeDSL/notebooks/print.ipynb
Normal file
@ -0,0 +1,425 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Printing with CuTe DSL\n",
|
||||
"\n",
|
||||
"This notebook demonstrates the different ways to print values in CuTe and explains the important distinction between static (compile-time) and dynamic (runtime) values.\n",
|
||||
"\n",
|
||||
"## Key Concepts\n",
|
||||
"- Static values: Known at compile time\n",
|
||||
"- Dynamic values: Only known at runtime\n",
|
||||
"- Different printing methods for different scenarios\n",
|
||||
"- Layout representation in CuTe\n",
|
||||
"- Tensor visualization and formatting"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Print Example Function\n",
|
||||
"\n",
|
||||
"The `print_example` function demonstrates several important concepts:\n",
|
||||
"\n",
|
||||
"### 1. Python's `print` vs CuTe's `cute.printf`\n",
|
||||
"- `print`: Can only show static values at compile time\n",
|
||||
"- `cute.printf`: Can display both static and dynamic values at runtime\n",
|
||||
"\n",
|
||||
"### 2. Value Types\n",
|
||||
"- `a`: Dynamic `Int32` value (runtime)\n",
|
||||
"- `b`: Static `Constexpr[int]` value (compile-time)\n",
|
||||
"\n",
|
||||
"### 3. Layout Printing\n",
|
||||
"Shows how layouts are represented differently in static vs dynamic contexts:\n",
|
||||
"- Static context: Unknown values shown as `?`\n",
|
||||
"- Dynamic context: Actual values displayed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def print_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
|
||||
" \"\"\"\n",
|
||||
" Demonstrates different printing methods in CuTe and how they handle static vs dynamic values.\n",
|
||||
"\n",
|
||||
" This example shows:\n",
|
||||
" 1. How Python's `print` function works with static values at compile time but can't show dynamic values\n",
|
||||
" 2. How `cute.printf` can display both static and dynamic values at runtime\n",
|
||||
" 3. The difference between types in static vs dynamic contexts\n",
|
||||
" 4. How layouts are represented in both printing methods\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" a: A dynamic Int32 value that will be determined at runtime\n",
|
||||
" b: A static (compile-time constant) integer value\n",
|
||||
" \"\"\"\n",
|
||||
" # Use Python `print` to print static information\n",
|
||||
" print(\">>>\", b) # => 2\n",
|
||||
" # `a` is dynamic value\n",
|
||||
" print(\">>>\", a) # => ?\n",
|
||||
"\n",
|
||||
" # Use `cute.printf` to print dynamic information\n",
|
||||
" cute.printf(\">?? {}\", a) # => 8\n",
|
||||
" cute.printf(\">?? {}\", b) # => 2\n",
|
||||
"\n",
|
||||
" print(\">>>\", type(a)) # => <class 'cutlass.Int32'>\n",
|
||||
" print(\">>>\", type(b)) # => <class 'int'>\n",
|
||||
"\n",
|
||||
" layout = cute.make_layout((a, b))\n",
|
||||
" print(\">>>\", layout) # => (?,2):(1,?)\n",
|
||||
" cute.printf(\">?? {}\", layout) # => (8,2):(1,8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compile and Run\n",
|
||||
"\n",
|
||||
"**Direct Compilation and Run**\n",
|
||||
" - `print_example(cutlass.Int32(8), 2)`\n",
|
||||
" - Compiles and runs in one step will execute both static and dynamic print\n",
|
||||
" * `>>>` stands for static print\n",
|
||||
" * `>??` stands for dynamic print"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
">>> 2\n",
|
||||
">>> ?\n",
|
||||
">>> Int32\n",
|
||||
">>> <class 'int'>\n",
|
||||
">>> (?,2):(1,?)\n",
|
||||
">?? 8\n",
|
||||
">?? 2\n",
|
||||
">?? (8,2):(1,8)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print_example(cutlass.Int32(8), 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compile Function\n",
|
||||
"\n",
|
||||
"When compiles the function with `cute.compile(print_example, cutlass.Int32(8), 2)`, Python interpreter \n",
|
||||
"traces code and only evaluate static expression and print static information."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
">>> 2\n",
|
||||
">>> ?\n",
|
||||
">>> Int32\n",
|
||||
">>> <class 'int'>\n",
|
||||
">>> (?,2):(1,?)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print_example_compiled = cute.compile(print_example, cutlass.Int32(8), 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Call compiled function\n",
|
||||
"\n",
|
||||
"Only print out runtime information"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
">?? 8\n",
|
||||
">?? 2\n",
|
||||
">?? (8,2):(1,8)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print_example_compiled(cutlass.Int32(8))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Format String Example\n",
|
||||
"\n",
|
||||
"The `format_string_example` function shows an important limitation:\n",
|
||||
"- F-strings in CuTe are evaluated at compile time\n",
|
||||
"- This means dynamic values won't show their runtime values in f-strings\n",
|
||||
"- Use `cute.printf` when you need to see runtime values"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Direct run output:\n",
|
||||
"a: ?, b: 2\n",
|
||||
"layout: (?,2):(1,?)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def format_string_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
|
||||
" \"\"\"\n",
|
||||
" Format string is evaluated at compile time.\n",
|
||||
" \"\"\"\n",
|
||||
" print(f\"a: {a}, b: {b}\")\n",
|
||||
"\n",
|
||||
" layout = cute.make_layout((a, b))\n",
|
||||
" print(f\"layout: {layout}\")\n",
|
||||
"\n",
|
||||
"print(\"Direct run output:\")\n",
|
||||
"format_string_example(cutlass.Int32(8), 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Printing Tensor Examples\n",
|
||||
"\n",
|
||||
"CuTe provides specialized functionality for printing tensors through the `print_tensor` operation. The `cute.print_tensor` takes the following parameter:\n",
|
||||
"- `Tensor` (required): A CuTe tensor object that you want to print. The tensor must support load and store operations\n",
|
||||
"- `verbose` (optional, default=False): A boolean flag that controls the level of detail in the output. When set to True, it will print indices details for each element in the tensor.\n",
|
||||
"\n",
|
||||
"Below example code shows the difference between verbose ON and OFF, and how to print a sub range of the given tensor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass.cute.runtime import from_dlpack\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_basic(x : cute.Tensor):\n",
|
||||
" # Print the tensor\n",
|
||||
" print(\"Basic output:\")\n",
|
||||
" cute.print_tensor(x)\n",
|
||||
" \n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_verbose(x : cute.Tensor):\n",
|
||||
" # Print the tensor with verbose mode\n",
|
||||
" print(\"Verbose output:\")\n",
|
||||
" cute.print_tensor(x, verbose=True)\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_slice(x : cute.Tensor, coord : tuple):\n",
|
||||
" # slice a 2D tensor from the 3D tensor\n",
|
||||
" sliced_data = cute.slice_(x, coord)\n",
|
||||
" y = cute.make_fragment(sliced_data.layout, sliced_data.element_type)\n",
|
||||
" # Convert to TensorSSA format by loading the sliced data into the fragment\n",
|
||||
" y.store(sliced_data.load())\n",
|
||||
" print(\"Slice output:\")\n",
|
||||
" cute.print_tensor(y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The default `cute.print_tensor` will output CuTe tensor with datatype, storage space, CuTe layout information, and print data in torch-style format."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Basic output:\n",
|
||||
"tensor(raw_ptr(0x000000000a5f1d50: f32, generic, align<4>) o (4,3,2):(6,2,1), data=\n",
|
||||
" [[[ 0.000000, 2.000000, 4.000000, ],\n",
|
||||
" [ 6.000000, 8.000000, 10.000000, ],\n",
|
||||
" [ 12.000000, 14.000000, 16.000000, ],\n",
|
||||
" [ 18.000000, 20.000000, 22.000000, ]],\n",
|
||||
"\n",
|
||||
" [[ 1.000000, 3.000000, 5.000000, ],\n",
|
||||
" [ 7.000000, 9.000000, 11.000000, ],\n",
|
||||
" [ 13.000000, 15.000000, 17.000000, ],\n",
|
||||
" [ 19.000000, 21.000000, 23.000000, ]]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def tensor_print_example1():\n",
|
||||
" shape = (4, 3, 2)\n",
|
||||
" \n",
|
||||
" # Creates [0,...,23] and reshape to (4, 3, 2)\n",
|
||||
" data = np.arange(24, dtype=np.float32).reshape(*shape) \n",
|
||||
" \n",
|
||||
" print_tensor_basic(from_dlpack(data))\n",
|
||||
"\n",
|
||||
"tensor_print_example1()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The verbosed print will show coodination details of each element in the tensor. The below example shows how we index element in a 2D 4x3 tensor space."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Verbose output:\n",
|
||||
"tensor(raw_ptr(0x000000000a814cc0: f32, generic, align<4>) o (4,3):(3,1), data= (\n",
|
||||
"\t(0,0)= 0.000000\n",
|
||||
"\t(0,1)= 1.000000\n",
|
||||
"\t(0,2)= 2.000000\n",
|
||||
"\t(1,0)= 3.000000\n",
|
||||
"\t(1,1)= 4.000000\n",
|
||||
"\t(1,2)= 5.000000\n",
|
||||
"\t(2,0)= 6.000000\n",
|
||||
"\t(2,1)= 7.000000\n",
|
||||
"\t(2,2)= 8.000000\n",
|
||||
"\t(3,0)= 9.000000\n",
|
||||
"\t(3,1)= 10.000000\n",
|
||||
"\t(3,2)= 11.000000\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def tensor_print_example2():\n",
|
||||
" shape = (4, 3)\n",
|
||||
" \n",
|
||||
" # Creates [0,...,11] and reshape to (4, 3)\n",
|
||||
" data = np.arange(12, dtype=np.float32).reshape(*shape) \n",
|
||||
" \n",
|
||||
" print_tensor_verbose(from_dlpack(data))\n",
|
||||
"\n",
|
||||
"tensor_print_example2()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To print a subset elements in the given Tensor, we can use cute.slice_ to select a range of the given tensor, load them into register and then print the values with `cute.print_tensor`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Slice output:\n",
|
||||
"tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (4):(3), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [Slice output:\n",
|
||||
" 6.000000, ],\n",
|
||||
" [ 9.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 4.000000, ],\n",
|
||||
" [ 5.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def tensor_print_example3():\n",
|
||||
" shape = (4, 3)\n",
|
||||
" \n",
|
||||
" # Creates [0,...,11] and reshape to (4, 3)\n",
|
||||
" data = np.arange(12, dtype=np.float32).reshape(*shape) \n",
|
||||
" \n",
|
||||
" print_tensor_slice(from_dlpack(data), (None, 0))\n",
|
||||
" print_tensor_slice(from_dlpack(data), (1, None))\n",
|
||||
"\n",
|
||||
"tensor_print_example3()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
390
examples/python/CuTeDSL/notebooks/tensor.ipynb
Normal file
390
examples/python/CuTeDSL/notebooks/tensor.ipynb
Normal file
@ -0,0 +1,390 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tensor\n",
|
||||
"\n",
|
||||
"A tensor in CuTe is created through the composition of two key components:\n",
|
||||
"\n",
|
||||
"1. An **Engine** (E) - A random-access, pointer-like object that supports:\n",
|
||||
" - Offset operation: `e + d → e` (offset engine by elements of a layout's codomain)\n",
|
||||
" - Dereference operation: `*e → v` (dereference engine to produce value)\n",
|
||||
"\n",
|
||||
"2. A **Layout** (L) - Defines the mapping from coordinates to offsets\n",
|
||||
"\n",
|
||||
"A tensor is formally defined as the composition of an engine E with a layout L, expressed as `T = E ∘ L`. When evaluating a tensor at coordinate c, it:\n",
|
||||
"\n",
|
||||
"1. Maps the coordinate c to the codomain using the layout\n",
|
||||
"2. Offsets the engine accordingly\n",
|
||||
"3. Dereferences the result to obtain the tensor's value\n",
|
||||
"\n",
|
||||
"This can be expressed mathematically as:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"T(c) = (E ∘ L)(c) = *(E + L(c))\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"## Example Usage\n",
|
||||
"\n",
|
||||
"Here's a simple example of creating a tensor using pointer and layout `(8,5):(5,1)` and fill with ones:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def create_tensor_from_ptr(ptr: cute.Pointer):\n",
|
||||
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
|
||||
" tensor = cute.make_tensor(ptr, layout)\n",
|
||||
" tensor.fill(1)\n",
|
||||
" cute.print_tensor(tensor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This creates a tensor where:\n",
|
||||
"- The engine is a pointer\n",
|
||||
"- The layout with shape `(8, 5)` and stride `(5, 1)`\n",
|
||||
"- The resulting tensor can be evaluated using coordinates defined by the layout\n",
|
||||
"\n",
|
||||
"We can test this by allocating buffer with torch and run test with pointer to torch tensor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x000000000736b0c0: f32, generic, align<4>) o (8,5):(5,1), data=\n",
|
||||
" [[ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" ...\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from cutlass.torch import dtype as torch_dtype\n",
|
||||
"import cutlass.cute.runtime as cute_rt\n",
|
||||
"\n",
|
||||
"a = torch.randn(8, 5, dtype=torch_dtype(cutlass.Float32))\n",
|
||||
"ptr_a = cute_rt.make_ptr(cutlass.Float32, a.data_ptr())\n",
|
||||
"\n",
|
||||
"create_tensor_from_ptr(ptr_a)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## DLPACK support \n",
|
||||
"\n",
|
||||
"CuTe DSL is designed to support dlpack protocol natively. This offers easy integration with frameworks \n",
|
||||
"supporting DLPack, e.g. torch, numpy, jax, tensorflow, etc.\n",
|
||||
"\n",
|
||||
"For more information, please refer to DLPACK project: https://github.com/dmlc/dlpack\n",
|
||||
"\n",
|
||||
"Calling `from_dlpack` can convert any tensor or ndarray object supporting `__dlpack__` and `__dlpack_device__`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass.cute.runtime import from_dlpack\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_dlpack(src: cute.Tensor):\n",
|
||||
" print(src)\n",
|
||||
" cute.print_tensor(src)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<ptr<f32, generic> o (8,5):(5,1)>\n",
|
||||
"tensor(raw_ptr(0x0000000007559340: f32, generic, align<4>) o (8,5):(5,1), data=\n",
|
||||
" [[-1.151769, 1.019397, -0.371175, -0.717776, 0.502176, ],\n",
|
||||
" [ 0.114282, 0.900084, 0.320770, 1.564574, -0.632329, ],\n",
|
||||
" [-0.570140, 0.178112, -0.423079, 1.936198, 0.003355, ],\n",
|
||||
" ...\n",
|
||||
" [-2.425393, -0.275528, 1.267157, -0.811101, -0.985456, ],\n",
|
||||
" [ 0.777889, -2.114074, 0.357184, -0.321312, -0.938138, ],\n",
|
||||
" [ 1.959564, 1.797602, 0.116901, 0.306198, -1.837295, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a = torch.randn(8, 5, dtype=torch_dtype(cutlass.Float32))\n",
|
||||
"\n",
|
||||
"print_tensor_dlpack(from_dlpack(a))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<ptr<f32, generic> o (8,8):(8,1)>\n",
|
||||
"tensor(raw_ptr(0x0000000007979da0: f32, generic, align<4>) o (8,8):(8,1), data=\n",
|
||||
" [[ 0.122739, -0.605744, -1.442022, ..., -0.356501, -0.993329, -0.091110, ],\n",
|
||||
" [ 0.278448, 0.318482, -0.276867, ..., 1.542181, -1.701539, -0.309454, ],\n",
|
||||
" [ 0.563565, -0.753936, 0.131214, ..., 0.437912, -0.482277, -0.051540, ],\n",
|
||||
" ...\n",
|
||||
" [-1.974096, -0.177881, 0.426807, ..., -1.579115, -0.304974, 0.451164, ],\n",
|
||||
" [ 0.149851, -0.704689, -0.295063, ..., -0.653001, 0.008871, 0.903916, ],\n",
|
||||
" [ 1.188619, 1.519662, 1.270734, ..., 0.404082, 0.173200, 0.093476, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"a = np.random.randn(8, 8).astype(np.float32)\n",
|
||||
"\n",
|
||||
"print_tensor_dlpack(from_dlpack(a))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tensor Evaluation Methods\n",
|
||||
"\n",
|
||||
"Tensors support two primary methods of evaluation:\n",
|
||||
"\n",
|
||||
"### 1. Full Evaluation\n",
|
||||
"When applying the tensor evaluation with a complete coordinate c, it computes the offset, applies it to the engine, \n",
|
||||
"and dereferences it to return the stored value. This is the straightforward case where you want to access \n",
|
||||
"a specific element of the tensor.\n",
|
||||
"\n",
|
||||
"### 2. Partial Evaluation (Slicing)\n",
|
||||
"When evaluating with an incomplete coordinate c = c' ⊕ c* (where c* represents the unspecified portion), \n",
|
||||
"the result is a new tensor which is a slice of the original tensor with its engine offset to account for \n",
|
||||
"the coordinates that were provided. This operation can be expressed as:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"T(c) = (E ∘ L)(c) = (E + L(c')) ∘ L(c*) = T'(c*)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Slicing effectively reduces the dimensionality of the tensor, creating a sub-tensor that can be \n",
|
||||
"further evaluated or manipulated."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a[2] = 10.000000 (equivalent to a[(2,0)])\n",
|
||||
"a[9] = 6.000000 (equivalent to a[(1,1)])\n",
|
||||
"a[2,0] = 10.000000\n",
|
||||
"a[2,4] = 14.000000\n",
|
||||
"a[(2,4)] = 14.000000\n",
|
||||
"a[2,3] = 100.000000\n",
|
||||
"a[(2,4)] = 101.000000\n",
|
||||
"tensor([[ 0., 1., 2., 3., 4.],\n",
|
||||
" [ 5., 6., 7., 8., 9.],\n",
|
||||
" [ 10., 11., 12., 100., 101.],\n",
|
||||
" [ 15., 16., 17., 18., 19.],\n",
|
||||
" [ 20., 21., 22., 23., 24.],\n",
|
||||
" [ 25., 26., 27., 28., 29.],\n",
|
||||
" [ 30., 31., 32., 33., 34.],\n",
|
||||
" [ 35., 36., 37., 38., 39.]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def tensor_access_item(a: cute.Tensor):\n",
|
||||
" # access data using linear index\n",
|
||||
" cute.printf(\"a[2] = {} (equivalent to a[{}])\", a[2],\n",
|
||||
" cute.make_identity_tensor(a.layout.shape)[2])\n",
|
||||
" cute.printf(\"a[9] = {} (equivalent to a[{}])\", a[9],\n",
|
||||
" cute.make_identity_tensor(a.layout.shape)[9])\n",
|
||||
"\n",
|
||||
" # access data using n-d coordinates, following two are equivalent\n",
|
||||
" cute.printf(\"a[2,0] = {}\", a[2, 0])\n",
|
||||
" cute.printf(\"a[2,4] = {}\", a[2, 4])\n",
|
||||
" cute.printf(\"a[(2,4)] = {}\", a[2, 4])\n",
|
||||
"\n",
|
||||
" # assign value to tensor@(2,4)\n",
|
||||
" a[2,3] = 100.0\n",
|
||||
" a[2,4] = 101.0\n",
|
||||
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
|
||||
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
|
||||
"\n",
|
||||
"@cute.kernel\n",
|
||||
"def print_tensor_gpu(ptr: cute.Pointer):\n",
|
||||
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
|
||||
" tensor = cute.make_tensor(ptr, layout)\n",
|
||||
"\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
"\n",
|
||||
" if tidx == 0:\n",
|
||||
" cute.print_tensor(tensor)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create a tensor with sequential data using torch\n",
|
||||
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",
|
||||
"tensor_access_item(from_dlpack(data))\n",
|
||||
"\n",
|
||||
"print(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Tensor as memory view\n",
|
||||
"\n",
|
||||
"In CUDA programming, different memory spaces have different characteristics in terms of access speed, scope, and lifetime:\n",
|
||||
"\n",
|
||||
"- **generic**: Default memory space that can refer to any other memory space.\n",
|
||||
"- **global memory (gmem)**: Accessible by all threads across all blocks, but has higher latency.\n",
|
||||
"- **shared memory (smem)**: Accessible by all threads within a block, with much lower latency than global memory.\n",
|
||||
"- **register memory (rmem)**: Thread-private memory with the lowest latency, but limited capacity.\n",
|
||||
"- **tensor memory (tmem)**: Specialized memory introduced in NVIDIA Blackwell architecture for tensor operations.\n",
|
||||
"\n",
|
||||
"When creating tensors in CuTe, you can specify the memory space to optimize performance based on your access patterns.\n",
|
||||
"\n",
|
||||
"For more information on CUDA memory spaces, see the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Coordinate Tensor\n",
|
||||
"\n",
|
||||
"A coordinate tensor is a special type of tensor that maps coordinates to coordinates rather than to values. \n",
|
||||
"The key distinction is that while regular tensors map coordinates to some value type (like numbers), \n",
|
||||
"coordinate tensors map coordinates to other coordinates.\n",
|
||||
"\n",
|
||||
"For example, given a shape (4,4), a coordinate tensor using row-major layout would appear as:\n",
|
||||
"\n",
|
||||
"\\begin{bmatrix} \n",
|
||||
"(0,0) & (0,1) & (0,2) & (0,3) \\\\\n",
|
||||
"(1,0) & (1,1) & (1,2) & (1,3) \\\\\n",
|
||||
"(2,0) & (2,1) & (2,2) & (2,3) \\\\\n",
|
||||
"(3,0) & (3,1) & (3,2) & (3,3)\n",
|
||||
"\\end{bmatrix}\n",
|
||||
"\n",
|
||||
"The same shape with a column-major layout would appear as:\n",
|
||||
"\n",
|
||||
"\\begin{bmatrix}\n",
|
||||
"(0,0) & (1,0) & (2,0) & (3,0) \\\\\n",
|
||||
"(0,1) & (1,1) & (2,1) & (3,1) \\\\\n",
|
||||
"(0,2) & (1,2) & (2,2) & (3,2) \\\\\n",
|
||||
"(0,3) & (1,3) & (2,3) & (3,3)\n",
|
||||
"\\end{bmatrix}\n",
|
||||
"\n",
|
||||
"The key points about coordinate tensors are:\n",
|
||||
"- Each element in the tensor is itself a coordinate tuple (i,j) rather than a scalar value\n",
|
||||
"- The coordinates map to themselves - so position (1,2) contains the coordinate (1,2)\n",
|
||||
"- The layout (row-major vs column-major) determines how these coordinate tuples are arranged in memory\n",
|
||||
"\n",
|
||||
"For example, coordinate tensors can be created using the `make_identity_tensor` utility:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"coord_tensor = make_identity_tensor(layout.shape())\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"This creates a tensor that maps each coordinate to itself, providing a reference point for understanding how other layouts transform these coordinates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<(0,0) o (8,4):(1@0,1@1)>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_coord(a: cute.Tensor):\n",
|
||||
" coord_tensor = cute.make_identity_tensor(a.layout.shape)\n",
|
||||
" print(coord_tensor)\n",
|
||||
"\n",
|
||||
"a = torch.randn(8,4, dtype=torch_dtype(cutlass.Float32))\n",
|
||||
"print_tensor_coord(from_dlpack(a))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"state": {},
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
558
examples/python/CuTeDSL/notebooks/tensorssa.ipynb
Normal file
558
examples/python/CuTeDSL/notebooks/tensorssa.ipynb
Normal file
@ -0,0 +1,558 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"from cutlass.cute.runtime import from_dlpack\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Introduction to the TensorSSA in CuTe DSL\n",
|
||||
"\n",
|
||||
"This tutorial introduces what is the `TensorSSA` and why we need it. We also give some examples to show how to use `TensorSSA`.\n",
|
||||
"\n",
|
||||
"## What is TensorSSA\n",
|
||||
"\n",
|
||||
"`TensorSSA` is a Python class that represents a tensor value in Static Single Assignment (SSA) form within the CuTe DSL. You can think of it as a tensor residing in a (simulated) register.\n",
|
||||
"\n",
|
||||
"## Why TensorSSA\n",
|
||||
"\n",
|
||||
"`TensorSSA` encapsulates the underlying MLIR tensor value into an object that's easier to manipulate in Python. By overloading numerous Python operators (like `+`, `-`, `*`, `/`, `[]`, etc.), it allows users to express tensor computations (primarily element-wise operations and reductions) in a more Pythonic way. These element-wise operations are then translated into optimized vectorization instructions.\n",
|
||||
"\n",
|
||||
"It's part of the CuTe DSL, serving as a bridge between the user-described computational logic and the lower-level MLIR IR, particularly for representing and manipulating register-level data.\n",
|
||||
"\n",
|
||||
"## When to use TensorSSA\n",
|
||||
"\n",
|
||||
"`TensorSSA` is primarily used in the following scenarios:\n",
|
||||
"\n",
|
||||
"### Load from memory and store to memory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
|
||||
"b_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
|
||||
"tensor(raw_ptr(0x0000000006cff170: f32, generic, align<4>) o (3,4):(4,1), data=\n",
|
||||
" [[ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
|
||||
" [ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
|
||||
" [ 2.000000, 2.000000, 2.000000, 2.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" \"\"\"\n",
|
||||
" Load data from memory and store the result to memory.\n",
|
||||
"\n",
|
||||
" :param res: The destination tensor to store the result.\n",
|
||||
" :param a: The source tensor to be loaded.\n",
|
||||
" :param b: The source tensor to be loaded.\n",
|
||||
" \"\"\"\n",
|
||||
" a_vec = a.load()\n",
|
||||
" print(f\"a_vec: {a_vec}\") # prints `a_vec: vector<12xf32> o (3, 4)`\n",
|
||||
" b_vec = b.load()\n",
|
||||
" print(f\"b_vec: {b_vec}\") # prints `b_vec: vector<12xf32> o (3, 4)`\n",
|
||||
" res.store(a_vec + b_vec)\n",
|
||||
" cute.print_tensor(res)\n",
|
||||
"\n",
|
||||
"a = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
|
||||
"b = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
|
||||
"c = np.zeros(12).reshape((3, 4)).astype(np.float32)\n",
|
||||
"load_and_store(from_dlpack(c), from_dlpack(a), from_dlpack(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Register-Level Tensor Operations\n",
|
||||
"\n",
|
||||
"When writing kernel logic, various computations, transformations, slicing, etc., are performed on data loaded into registers."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor_value<vector<24xf32> o (4, 2, 3)> -> tensor_value<vector<12xf32> o (4, 3)>\n",
|
||||
"tensor(raw_ptr(0x00000000071acaf0: f32, generic, align<4>) o (4,3):(3,1), data=\n",
|
||||
" [[ 3.000000, 4.000000, 5.000000, ],\n",
|
||||
" [ 9.000000, 10.000000, 11.000000, ],\n",
|
||||
" [ 15.000000, 16.000000, 17.000000, ],\n",
|
||||
" [ 21.000000, 22.000000, 23.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):\n",
|
||||
" \"\"\"\n",
|
||||
" Apply slice operation on the src tensor and store the result to the dst tensor.\n",
|
||||
"\n",
|
||||
" :param src: The source tensor to be sliced.\n",
|
||||
" :param dst: The destination tensor to store the result.\n",
|
||||
" :param indices: The indices to slice the source tensor.\n",
|
||||
" \"\"\"\n",
|
||||
" src_vec = src.load()\n",
|
||||
" dst_vec = src_vec[indices]\n",
|
||||
" print(f\"{src_vec} -> {dst_vec}\")\n",
|
||||
" if isinstance(dst_vec, cute.TensorSSA):\n",
|
||||
" dst.store(dst_vec)\n",
|
||||
" cute.print_tensor(dst)\n",
|
||||
" else:\n",
|
||||
" dst[0] = dst_vec\n",
|
||||
" cute.print_tensor(dst)\n",
|
||||
"\n",
|
||||
"def slice_1():\n",
|
||||
" src_shape = (4, 2, 3)\n",
|
||||
" dst_shape = (4, 3)\n",
|
||||
" indices = (None, 1, None)\n",
|
||||
"\n",
|
||||
" \"\"\"\n",
|
||||
" a:\n",
|
||||
" [[[ 0. 1. 2.]\n",
|
||||
" [ 3. 4. 5.]]\n",
|
||||
"\n",
|
||||
" [[ 6. 7. 8.]\n",
|
||||
" [ 9. 10. 11.]]\n",
|
||||
"\n",
|
||||
" [[12. 13. 14.]\n",
|
||||
" [15. 16. 17.]]\n",
|
||||
"\n",
|
||||
" [[18. 19. 20.]\n",
|
||||
" [21. 22. 23.]]]\n",
|
||||
" \"\"\"\n",
|
||||
" a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
|
||||
" dst = np.random.randn(*dst_shape).astype(np.float32)\n",
|
||||
" apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
|
||||
"\n",
|
||||
"slice_1()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor_value<vector<24xf32> o (4, 2, 3)> -> ?\n",
|
||||
"tensor(raw_ptr(0x00000000013cbbe0: f32, generic, align<4>) o (1):(1), data=\n",
|
||||
" [ 10.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def slice_2():\n",
|
||||
" src_shape = (4, 2, 3)\n",
|
||||
" dst_shape = (1,)\n",
|
||||
" indices = 10\n",
|
||||
" a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
|
||||
" dst = np.random.randn(*dst_shape).astype(np.float32)\n",
|
||||
" apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
|
||||
"\n",
|
||||
"slice_2()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Arithmetic Operations\n",
|
||||
"\n",
|
||||
"As we mentioned earlier, there're many tensor operations whose operands are `TensorSSA`. And they are all element-wise operations. We give some examples below.\n",
|
||||
"\n",
|
||||
"### Binary Operations\n",
|
||||
"\n",
|
||||
"For binary operations, the LHS operand is `TensorSSA` and the RHS operand can be either `TensorSSA` or `Numeric`. When the RHS is `Numeric`, it will be broadcast to a `TensorSSA`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
" add_res = a_vec + b_vec\n",
|
||||
" res.store(add_res)\n",
|
||||
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
"\n",
|
||||
" sub_res = a_vec - b_vec\n",
|
||||
" res.store(sub_res)\n",
|
||||
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
"\n",
|
||||
" mul_res = a_vec * b_vec\n",
|
||||
" res.store(mul_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" div_res = a_vec / b_vec\n",
|
||||
" res.store(div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
"\n",
|
||||
" floor_div_res = a_vec // b_vec\n",
|
||||
" res.store(floor_div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
"\n",
|
||||
" mod_res = a_vec % b_vec\n",
|
||||
" res.store(mod_res)\n",
|
||||
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"a = np.empty((3,), dtype=np.float32)\n",
|
||||
"a.fill(1.0)\n",
|
||||
"b = np.empty((3,), dtype=np.float32)\n",
|
||||
"b.fill(2.0)\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"binary_op_1(from_dlpack(res), from_dlpack(a), from_dlpack(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
|
||||
" a_vec = a.load()\n",
|
||||
"\n",
|
||||
" add_res = a_vec + c\n",
|
||||
" res.store(add_res)\n",
|
||||
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
"\n",
|
||||
" sub_res = a_vec - c\n",
|
||||
" res.store(sub_res)\n",
|
||||
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
"\n",
|
||||
" mul_res = a_vec * c\n",
|
||||
" res.store(mul_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" div_res = a_vec / c\n",
|
||||
" res.store(div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
"\n",
|
||||
" floor_div_res = a_vec // c\n",
|
||||
" res.store(floor_div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
"\n",
|
||||
" mod_res = a_vec % c\n",
|
||||
" res.store(mod_res)\n",
|
||||
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
"\n",
|
||||
"a = np.empty((3,), dtype=np.float32)\n",
|
||||
"a.fill(1.0)\n",
|
||||
"c = 2.0\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"binary_op_2(from_dlpack(res), from_dlpack(a), c)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[False True False]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
" gt_res = a_vec > b_vec\n",
|
||||
" res.store(gt_res)\n",
|
||||
"\n",
|
||||
" \"\"\"\n",
|
||||
" ge_res = a_ >= b_ # [False, True, False]\n",
|
||||
" lt_res = a_ < b_ # [True, False, True]\n",
|
||||
" le_res = a_ <= b_ # [True, False, True]\n",
|
||||
" eq_res = a_ == b_ # [False, False, False]\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
"a = np.array([1, 2, 3], dtype=np.float32)\n",
|
||||
"b = np.array([2, 1, 4], dtype=np.float32)\n",
|
||||
"res = np.empty((3,), dtype=np.bool_)\n",
|
||||
"binary_op_3(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
|
||||
"print(res) # prints [False, True, False]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[3 0 7]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
" xor_res = a_vec ^ b_vec\n",
|
||||
" res.store(xor_res)\n",
|
||||
"\n",
|
||||
" # or_res = a_vec | b_vec\n",
|
||||
" # res.store(or_res) # prints [3, 2, 7]\n",
|
||||
"\n",
|
||||
" # and_res = a_vec & b_vec\n",
|
||||
" # res.store(and_res) # prints [0, 2, 0]\n",
|
||||
"\n",
|
||||
"a = np.array([1, 2, 3], dtype=np.int32)\n",
|
||||
"b = np.array([2, 2, 4], dtype=np.int32)\n",
|
||||
"res = np.empty((3,), dtype=np.int32)\n",
|
||||
"binary_op_4(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
|
||||
"print(res) # prints [3, 0, 7]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Unary Operations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-0.756802, ],\n",
|
||||
" [-0.756802, ],\n",
|
||||
" [-0.756802, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 16.000000, ],\n",
|
||||
" [ 16.000000, ],\n",
|
||||
" [ 16.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def unary_op_1(res: cute.Tensor, a: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
"\n",
|
||||
" sqrt_res = cute.math.sqrt(a_vec)\n",
|
||||
" res.store(sqrt_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" sin_res = cute.math.sin(a_vec)\n",
|
||||
" res.store(sin_res)\n",
|
||||
" cute.print_tensor(res) # prints [-0.756802, -0.756802, -0.756802]\n",
|
||||
"\n",
|
||||
" exp2_res = cute.math.exp2(a_vec)\n",
|
||||
" res.store(exp2_res)\n",
|
||||
" cute.print_tensor(res) # prints [16.000000, 16.000000, 16.000000]\n",
|
||||
"\n",
|
||||
"a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"unary_op_1(from_dlpack(res), from_dlpack(a))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Reduction Operation\n",
|
||||
"\n",
|
||||
"The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, `ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and performs this reduction along the dimensions specified by the `reduction_profile.`. The result is typically a new `TensorSSA` with reduced dimensions or a scalar value if reduces across all axes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"21.000000\n",
|
||||
"tensor(raw_ptr(0x00007ffd1ea2bca0: f32, rmem, align<32>) o (2):(1), data=\n",
|
||||
" [ 6.000000, ],\n",
|
||||
" [ 15.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00007ffd1ea2bcc0: f32, rmem, align<32>) o (3):(1), data=\n",
|
||||
" [ 6.000000, ],\n",
|
||||
" [ 8.000000, ],\n",
|
||||
" [ 10.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def reduction_op(a: cute.Tensor):\n",
|
||||
" \"\"\"\n",
|
||||
" Apply reduction operation on the src tensor.\n",
|
||||
"\n",
|
||||
" :param src: The source tensor to be reduced.\n",
|
||||
" \"\"\"\n",
|
||||
" a_vec = a.load()\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 0.0,\n",
|
||||
" reduction_profile=0\n",
|
||||
" )\n",
|
||||
" cute.printf(red_res) # prints 21.000000\n",
|
||||
"\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 0.0,\n",
|
||||
" reduction_profile=(None, 1)\n",
|
||||
" )\n",
|
||||
" # We can't print the TensorSSA directly at this point, so we store it to a new Tensor and print it.\n",
|
||||
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
|
||||
" res.store(red_res)\n",
|
||||
" cute.print_tensor(res) # prints [6.000000, 15.000000]\n",
|
||||
"\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 1.0,\n",
|
||||
" reduction_profile=(1, None)\n",
|
||||
" )\n",
|
||||
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
|
||||
" res.store(red_res)\n",
|
||||
" cute.print_tensor(res) # prints [6.000000, 8.000000, 10.000000]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n",
|
||||
"reduction_op(from_dlpack(a))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user