v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@@ -1,25 +1,31 @@
 # Changelog
 # CUTLASS 4.x
-## [4.0.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-05-09)
+## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)
 ### CuTe DSL
 * CuTe DSL, a Python DSL centered around CuTe's abstractions
 - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
-- [DSL quick start](./media/docs/pythonDSL/quick_start.rst)
+- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
-- [DSL Overview](./media/docs/pythonDSL/overview.rst)
+- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
 * [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass)
 * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
 - [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
 - [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
 - [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
+- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
 - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
 - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
+- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
+- [C-structure-based custom interface between JIT functions and user code](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/jit_argument.py)
 * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
+* API updates
+  - Fixed an API mismatch in class ``cute.runtime.Pointer``: renamed ``element_type`` to ``dtype`` to match ``typing.Pointer``
 ### CUTLASS C++
 * Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9
-- 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling.
+- 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note: 101a has been supported since CUTLASS 3.9.
 * Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
 - For example:
@@ -30,9 +36,25 @@
 - Added non-power-of-two tile sizes.
 - Improved performance for K-major scale factors.
 - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions.
+* Support LSE output in the Blackwell FMHA forward kernel in example 77.
+* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
+- Enable runtime datatypes for Blackwell grouped GEMM; profiler support is also added.
+- Enable kernel parameter exploration for Blackwell grouped GEMM (raster_order, swizzle).
+* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
+* Add dynamic and preferred cluster support for convolution kernels.
+* Support Blackwell SM120 blockwise dense GEMM in the CUTLASS core library and the CUTLASS profiler.
+* Fix profiler issues that caused no output or a "not supported" error for some kernels.
+* Port optimizations to the BlockScaled collective and kernel layers.
+* New [Hopper FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
+* CuTe changes:
+- Rework `cute::copy_if` so that the predicate is a true CuTe Tensor rather than a lambda, and introduce transform-tensors to avoid extra register or load/store overhead when using bool tensors (a minimal usage sketch appears just after this section).
+- New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) showing the usage of `copy_if` in a tiled copy.
+- Add a [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
+- Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
 * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 * Optimal code generation with CUDA toolkit versions 12.9.
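To illustrate the `cute::copy_if` rework noted above, here is a minimal host-side sketch (not taken from the repository). The buffers, shapes, and predicate pattern are invented; the only assumptions are that CUTLASS's `include/` directory is on the include path and that the three-argument `cute::copy_if(pred, src, dst)` overload accepts a bool-valued CuTe Tensor as its predicate.

```cpp
// Sketch: predicated element-wise copy where the predicate is itself a CuTe Tensor.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  float src[8], dst[8];
  bool  prd[8];
  for (int i = 0; i < 8; ++i) {
    src[i] = float(i);
    dst[i] = -1.0f;
    prd[i] = (i % 2 == 0);   // copy only even-indexed elements
  }

  auto layout = make_layout(Int<8>{});       // 1-D layout of 8 elements
  Tensor tS = make_tensor(&src[0], layout);  // source tensor
  Tensor tD = make_tensor(&dst[0], layout);  // destination tensor
  Tensor tP = make_tensor(&prd[0], layout);  // predicate tensor (bool CuTe Tensor)

  copy_if(tP, tS, tD);  // dst(i) = src(i) wherever prd[i] is true; others untouched
  return 0;
}
```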
 # CUTLASS 3.x
 ## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03)
@@ -82,7 +104,7 @@
 - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
 - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
 - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
-- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
+- More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
 * Support `void` as the D element in sm100 kernel epilogues.
 * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 * Optimal code generation with CUDA toolkit versions 12.8U1.
@@ -101,7 +123,7 @@
 - [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp).
 - [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp).
 - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
-- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
+- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
 - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
 * Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
 - [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
@@ -139,11 +161,11 @@
 - A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
 - A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
 * Documentation updates:
-- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-sm100-gemm-kernel).
+- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
-- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md)
+- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html)
-- A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures.
+- A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures.
-- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#target-architecture).
+- Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture).
-- Updates to [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper.
+- Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.
 ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
 - [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
@@ -156,7 +178,7 @@
 + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
 + Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
 + A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
-- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler).
+- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler).
 - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 - Optimal code generation with CUDA toolkit versions 12.6.
@@ -168,14 +190,14 @@
 + [INT8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
 + [TF32](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
 - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
-- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/cpp/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
+- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
 - [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
-- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/cpp/dependent_kernel_launch.md).
+- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html).
-- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
+- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
 - A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
 - A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
 - A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
-- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper).
+- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
 - A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h)
 - Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu).
 - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
@@ -198,7 +220,7 @@
 - Support for residual add (beta != 0) in convolution kernels.
 - A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
 - A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt).
-- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md).
+- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html).
 - Better support for MSVC as a host compiler.
 - Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
 - Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
@@ -206,13 +228,13 @@
 ## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
 - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp)
-+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md).
++ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html).
 + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp).
 + Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
 + [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
 + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design!
 - Support for [Ada (SM89) FP8 tensor cores via the 2.x API](https://github.com/NVIDIA/cutlass/tree/main/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
-- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/cpp/README.md) in CuTe and CUTLASS 3.x
+- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
 + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
 + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
 - 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
@@ -279,7 +301,7 @@
 * Updates and bugfixes from the community (thanks!)
 ## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
-* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/cpp/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python).
+* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python).
 * New [efficient epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
 * Support for [fused epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU and GELU, using the new efficient epilogues.
 * New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.

View File

@@ -175,11 +175,13 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
 endif()
 if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
-  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 101 101a 120 120a)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
 endif()
 if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
-  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 101f 120f)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
 endif()
 set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
@@ -344,6 +346,10 @@ if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES)
   list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
 endif()
+if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
+  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
+endif()
 set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")
 #
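The hunk above defines `-DCUTLASS_SM100_FAMILY_ARCHS_ENABLED` whenever a family-specific architecture (`100f`/`101f`) appears in `CUTLASS_NVCC_ARCHS`. A minimal, self-contained sketch (not from the repository) of how client code can key off that flag; the printed messages are illustrative only.

```cpp
// Sketch: report whether this translation unit was built with the sm100
// family-specific flag added by the CMake hunk above.
#include <cstdio>

int main() {
#if defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)
  std::printf("Built with sm100 family-specific (100f/101f) architectures enabled.\n");
#else
  std::printf("Built without sm100 family-specific architectures.\n");
#endif
  return 0;
}
```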

View File

@@ -1056,7 +1056,7 @@ HTML_STYLESHEET =
 # defined cascading style sheet that is included after the standard style sheets
 # created by doxygen. Using this option one can overrule certain style aspects.
 # This is preferred over using HTML_STYLESHEET since it does not replace the
-# standard style sheet and is therefore more robust against future updates.
+# standard style sheet and is therefor more robust against future updates.
 # Doxygen will copy the style sheet file to the output directory. For an example
 # see the documentation.
 # This tag requires that the tag GENERATE_HTML is set to YES.
@@ -1940,7 +1940,7 @@ PREDEFINED =
 EXPAND_AS_DEFINED =
 # If the SKIP_FUNCTION_MACROS tag is set to YES then doxygen's preprocessor will
-# remove all references to function-like macros that are alone on a line, have an
+# remove all refrences to function-like macros that are alone on a line, have an
 # all uppercase name, and do not end with a semicolon. Such function macros are
 # typically used for boiler-plate code, and will confuse the parser if not
 # removed.

View File

@@ -40,21 +40,22 @@ designs, and bringing optimized solutions into production.
 CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025.
 To get started quickly, please refer to:
-- [CUTLASS C++ Quick Start Guide](./media/docs/cpp/quickstart.md).
+- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
-- [CuTe DSL Quick Start Guide](./media/docs/pythonDSL/quick_start.rst).
+- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).
 # What's New in CUTLASS 4.0
 ## CuTe DSL
 * CuTe DSL, a Python DSL centered around CuTe's abstractions
 - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
-- [DSL Quick Start](./media/docs/pythonDSL/quick_start.rst)
+- [DSL Quick Start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
-- [DSL Overview](./media/docs/pythonDSL/overview.rst)
+- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
 * [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass)
 * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
 - [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
 - [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
 - [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
+- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
 - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
 - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
 * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
@@ -71,13 +72,15 @@ To get started quickly - please refer :
 - Added non-power-of-two tile sizes.
 - Improved performance for K-major scale factors.
 - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions.
+* Support LSE output in the Blackwell FMHA forward kernel.
+* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
 * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 * Optimal code generation with CUDA toolkit versions 12.9.
 Note: CUTLASS 4.x builds are known to fail on Windows platforms for all CUDA toolkits.
 CUTLASS team is working on a fix.
-**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.**
+**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.**
 # Performance
@@ -119,7 +122,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
 CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
 This greatly simplifies the design and improves code composability and readability.
 More documentation specific to CuTe can be found in its
-[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md).
+[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html).
 # Compatibility
@@ -202,7 +205,7 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
 compiled for Blackwell SM100 architecture with arch conditional features
 (using `sm100a`) are not compatible with RTX 50 series GPUs.
-Please refer to the [functionality documentation](./media/docs/cpp/functionality.md)
+Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html)
 for details on which kernels require which target architectures.
 # Documentation
@@ -210,22 +213,22 @@ for details on which kernels require which target architectures.
 CUTLASS is described in the following documents and the accompanying
 [Doxygen documentation](https://nvidia.github.io/cutlass).
-- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS
+- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
-- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS
+- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
-- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
+- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
-- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
+- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
-- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
+- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
-- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
+- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
-- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
+- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
-- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project
+- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
-- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code
+- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code
-- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
+- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
-- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
+- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
-- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory
+- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
-- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
+- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
-- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application
+- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application
-- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development
+- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
-- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
+- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
 kernels in the same stream, and how it is used in CUTLASS.
 # Resources
@@ -245,7 +248,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
 paths.
 CUTLASS unit tests, examples, and utilities can be built with CMake.
-The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md).
+The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
 Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
 on your system.
@@ -290,7 +293,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
 and template concepts defined in the CUTLASS project.
 A detailed explanation of the source code organization may be found in the
-[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below.
+[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below.
 ## CUTLASS Template Library
@@ -364,7 +367,7 @@ tools/
 The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
 basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
-Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md).
+Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
 # Performance Profiling
@@ -580,9 +583,9 @@ reference_device: Passed
 ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
 - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
-- [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples)
+- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples)
-- [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples)
+- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples)
-- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md)
+- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html)
 # About

View File

@@ -42,7 +42,6 @@
 #include "cute/algorithm/functional.hpp"
 #include "cute/atom/mma_atom.hpp"
 #include "cute/algorithm/gemm.hpp"
-#include "cute/tensor_predicate.hpp"
 #include "cute/numeric/arithmetic_tuple.hpp"
 #include "cutlass/arch/grid_dependency_control.h"
@@ -288,7 +287,7 @@ struct CollectiveMma<
 constexpr int tma_alignment_bits = 128;
 auto problem_shape_MNKL = append<4>(problem_shape, 1);
 auto [M,N,K,L] = problem_shape_MNKL;
 constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
 bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
 constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
@@ -445,7 +444,7 @@ struct CollectiveMma<
 copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
 ++k_tile_iter;
 if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
 launch_dep_grids = true;
 cutlass::arch::launch_dependent_grids();
 }
@@ -453,7 +452,7 @@ struct CollectiveMma<
 // Advance smem_pipe_write
 ++smem_pipe_write;
 }
 if (!disable_gdc && !launch_dep_grids) {
 cutlass::arch::launch_dependent_grids();
 }
 }
@@ -533,7 +532,7 @@ struct CollectiveMma<
 copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
 ++k_tile_iter;
 if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
 launch_dep_grids = true;
 cutlass::arch::launch_dependent_grids();
 }
@@ -541,7 +540,7 @@ struct CollectiveMma<
 // Advance smem_pipe_write
 ++smem_pipe_write;
 }
 if (!disable_gdc && !launch_dep_grids) {
 cutlass::arch::launch_dependent_grids();
 }
 }
@@ -634,9 +633,9 @@ struct CollectiveMma<
 // Issue the epilogue waits
 if (lane_predicate) {
 /* This helps avoid early exit of blocks in Cluster
 * Waits for all stages to either be released (all
 * Consumer UNLOCKs), or if the stage was never used
 * then would just be acquired since the phase was
 * still inverted from make_producer_start_state
 */
 pipeline.producer_tail(smem_pipe_write);
@@ -854,7 +853,7 @@ struct CollectiveMma<
 k_tile_count -= prologue_mma_count;
 smem_pipe_release.advance(k_tile_count);
 // Wait on all GMMAs to complete
 warpgroup_wait<0>();
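A small, self-contained sketch (not from the repository) of the arithmetic behind the `check_alignment` calls in the `@@ -288` hunk above: the 128-bit TMA alignment requirement is converted into a minimum contiguous element count that depends on the element width. Only `cutlass::sizeof_bits` and `cutlass::half_t` are assumed from CUTLASS headers; the element types chosen here are illustrative.

```cpp
// Sketch: minimum contiguous elements required for 128-bit TMA alignment.
#include <cstdio>
#include "cutlass/numeric_types.h"

int main() {
  constexpr int tma_alignment_bits = 128;

  constexpr int min_elements_f16 =
      tma_alignment_bits / cutlass::sizeof_bits<cutlass::half_t>::value;  // 128 / 16 = 8
  constexpr int min_elements_f32 =
      tma_alignment_bits / cutlass::sizeof_bits<float>::value;            // 128 / 32 = 4

  std::printf("min elements for TMA alignment: half_t=%d, float=%d\n",
              min_elements_f16, min_elements_f32);
  return 0;
}
```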

View File

@@ -133,7 +133,7 @@ using TP = _8;
 static constexpr int TP_ = TP{};
 #if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
-(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
+(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
 // Distributed GEMM tiling/sharding schedule
 // Choices:
@@ -252,7 +252,8 @@ HostTensorB tensor_B_arr[TP_];
 HostTensorD tensor_C_arr[TP_];
 HostTensorD tensor_D_arr[TP_];
-#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
+       // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
 /////////////////////////////////////////////////////////////////////////////////////////////////
 /// Testbed utility types
@@ -345,7 +346,7 @@ struct Result {
 };
 #if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
-(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
+(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
 /////////////////////////////////////////////////////////////////////////////////////////////////
 /// GEMM setup and evaluation
@@ -803,17 +804,18 @@ int run(Options &options) {
 return 0;
 }
-#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
+       // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 int main(int argc, char const **args) {
-// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
+// CUTLASS must be compiled with CUDA Toolkit 12.6 or newer to run this example
 // and must have compute capability at least 90.
-// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
+// Some necessary cuda graph APIs were only introduced in CUDA 12.6.
-if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
+if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) {
-std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
+std::cerr << "This example requires CUDA 12.6 or newer." << std::endl;
 // Returning zero so this test passes on older Toolkits. Its actions are no-op.
 return 0;
 }
@@ -857,11 +859,11 @@ int main(int argc, char const **args) {
 // Evaluate CUTLASS kernels
 //
-#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
 run(options);
 #else
 std::cerr
-<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.4 or later." << std::endl;
+<< "This example must be compiled with `sm90a` and CUDA Toolkit 12.6 or later." << std::endl;
 return 0;
 #endif

View File


@ -250,8 +250,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types /// Testbed utility types
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
/// Result structure /// Result structure
struct Result struct Result
{ {
@ -518,7 +516,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
} }
arguments.scheduler.raster_order = options.raster; arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
arguments.scheduler.max_swizzle_size = options.swizzle; arguments.scheduler.max_swizzle_size = options.swizzle;
@ -690,10 +688,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
std::string raster = "Heuristic"; std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) { if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N"; raster = "Along N";
} }
else if (options.raster == RasterOrderOptions::AlongM) { else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M"; raster = "Along M";
} }
@ -747,7 +745,7 @@ int main(int argc, char const **args) {
// Parse options // Parse options
// //
Options<RasterOrderOptions, ProblemShape> options; Options<ProblemShape> options;
options.parse(argc, args); options.parse(argc, args);

View File


@ -253,8 +253,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types /// Testbed utility types
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
/// Result structure /// Result structure
struct Result struct Result
{ {
@ -523,7 +521,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
} }
arguments.scheduler.raster_order = options.raster; arguments.scheduler.raster_order = options.raster_order;
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
arguments.scheduler.max_swizzle_size = options.swizzle; arguments.scheduler.max_swizzle_size = options.swizzle;
@ -699,10 +697,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)
std::string raster = "Heuristic"; std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) { if (options.raster_order == RasterOrderOptions::AlongN) {
raster = "Along N"; raster = "Along N";
} }
else if (options.raster == RasterOrderOptions::AlongM) { else if (options.raster_order == RasterOrderOptions::AlongM) {
raster = "Along M"; raster = "Along M";
} }
@ -755,7 +753,7 @@ int main(int argc, char const **args) {
// Parse options // Parse options
// //
Options<RasterOrderOptions, ProblemShape> options; Options<ProblemShape> options;
options.parse(argc, args); options.parse(argc, args);

View File


@ -30,10 +30,9 @@
**************************************************************************************************/ **************************************************************************************************/
// Command line options parsing // Command line options parsing
template<typename _RasterOrderOptions, typename _ProblemShape> using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
template<typename _ProblemShape>
struct Options { struct Options {
using RasterOrderOptions = _RasterOrderOptions;
using ProblemShape = _ProblemShape; using ProblemShape = _ProblemShape;
bool help = false; bool help = false;
@ -50,7 +49,7 @@ struct Options {
int const m_alignment = 128; int const m_alignment = 128;
int const n_alignment = 128; int const n_alignment = 128;
RasterOrderOptions raster; RasterOrderOptions raster_order;
int swizzle; int swizzle;
// Parses the command line // Parses the command line
@ -74,13 +73,13 @@ struct Options {
cmd.get_cmd_line_argument("raster", raster_char); cmd.get_cmd_line_argument("raster", raster_char);
if (raster_char == 'N' || raster_char == 'n') { if (raster_char == 'N' || raster_char == 'n') {
raster = RasterOrderOptions::AlongN; raster_order = RasterOrderOptions::AlongN;
} }
else if (raster_char == 'M' || raster_char == 'm') { else if (raster_char == 'M' || raster_char == 'm') {
raster = RasterOrderOptions::AlongM; raster_order = RasterOrderOptions::AlongM;
} }
else if (raster_char == 'H' || raster_char == 'h') { else if (raster_char == 'H' || raster_char == 'h') {
raster = RasterOrderOptions::Heuristic; raster_order = RasterOrderOptions::Heuristic;
} }
cmd.get_cmd_line_argument("swizzle", swizzle, 1); cmd.get_cmd_line_argument("swizzle", swizzle, 1);

View File


@ -543,7 +543,7 @@ int run(Options &options) {
int main(int argc, char const **args) { int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example // CUTLASS must be compiled with CUDA 12.8 Toolkit or newer to run this example
// and must have compute capability at least 100. // and must have compute capability at least 100.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -560,7 +560,6 @@ int main(int argc, char const **args) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0; return 0;
} }
// //
// Parse options // Parse options
// //

View File


@ -237,7 +237,7 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;
/// Testbed utility types /// Testbed utility types
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing // Command line options parsing
struct Options { struct Options {

View File


@ -300,7 +300,7 @@ auto make_iterator(T* ptr) {
/// Testbed utility types /// Testbed utility types
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing // Command line options parsing
struct Options { struct Options {

View File


@ -490,7 +490,7 @@ int run(Options &options)
int main(int argc, char const **args) { int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90. // and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id)); CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0); cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0; return 0;
} }
// //
// Parse options // Parse options
// //

View File


@ -490,7 +490,7 @@ int run(Options &options)
int main(int argc, char const **args) { int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
// and must have compute capability at least 90. // and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
@ -503,11 +503,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id)); CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0); cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0; return 0;
} }
// //
// Parse options // Parse options
// //

View File


@ -499,11 +499,11 @@ int main(int argc, char const **args) {
CUDA_CHECK(cudaGetDevice(&current_device_id)); CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0); cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
return 0; return 0;
} }
// //
// Parse options // Parse options
// //

View File


@ -117,15 +117,17 @@ struct Options {
int q = 256; int q = 256;
int k = 256; int k = 256;
int d = 128; int d = 128;
int warmup_iterations = 1;
int iterations = 3; int iterations = 3;
int tensor_ring_buffers = 1;
bool verify = false; bool verify = false;
bool verbose = false; bool verbose = false;
bool causal = false; bool causal = false;
bool residual = false; bool residual = false;
bool varlen = false; bool varlen = false;
bool persistent = false;
int sm_count = 0; int sm_count = 0;
std::string kernel_filter; std::string kernel_filter;
InitStyle init_style_q = InitStyle::kRandom; InitStyle init_style_q = InitStyle::kRandom;
@ -189,10 +191,15 @@ struct Options {
if (b == -1) b = 16384 / k; if (b == -1) b = 16384 / k;
if (b == 0) b = 1; if (b == 0) b = 1;
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);
verify = cmd.check_cmd_line_flag("verify"); verify = cmd.check_cmd_line_flag("verify");
verbose = cmd.check_cmd_line_flag("verbose"); verbose = cmd.check_cmd_line_flag("verbose");
varlen = cmd.check_cmd_line_flag("varlen"); varlen = cmd.check_cmd_line_flag("varlen");
persistent = cmd.check_cmd_line_flag("persistent");
std::string mask; std::string mask;
cmd.get_cmd_line_argument<std::string>("mask", mask, ""); cmd.get_cmd_line_argument<std::string>("mask", mask, "");
if (mask == "no" || mask == "") { if (mask == "no" || mask == "") {
@ -210,7 +217,6 @@ struct Options {
causal = false; causal = false;
} }
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
@ -235,10 +241,13 @@ struct Options {
<< " --q=<int> Sets the Q extent\n" << " --q=<int> Sets the Q extent\n"
<< " --k=<int> Sets the K extent\n" << " --k=<int> Sets the K extent\n"
<< " --d=<int> Sets the D extentn" << " --d=<int> Sets the D extentn"
<< " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
<< " --warmup_iterations=<int> Sets the warmup iterations\n"
<< " --iterations=<int> Benchmarking iterations\n" << " --iterations=<int> Benchmarking iterations\n"
<< " --verify Verify results\n" << " --verify Verify results\n"
<< " --verbose Print smem and execution time per kernel\n" << " --verbose Print smem and execution time per kernel\n"
<< " --mask=<no|residual|causal> Enables masking\n" << " --mask=<no|residual|causal> Enables masking\n"
<< " --persistent Enables persistent scheduler\n"
<< " --varlen Enables variable sequence length\n" << " --varlen Enables variable sequence length\n"
<< " B*Q and B*K become the total sequence length\n" << " B*Q and B*K become the total sequence length\n"
<< " and are split B-ways, alternatingly +10% and -10%\n" << " and are split B-ways, alternatingly +10% and -10%\n"
@ -379,40 +388,55 @@ struct FwdRunner {
StrideLSE stride_LSE; StrideLSE stride_LSE;
uint64_t seed = 0; uint64_t seed = 0;
DeviceAllocation<Element> block_Q; struct DeviceBuffer {
DeviceAllocation<Element> block_K; DeviceAllocation<Element> block_Q;
DeviceAllocation<Element> block_V; DeviceAllocation<Element> block_K;
DeviceAllocation<ElementOut> block_O; DeviceAllocation<Element> block_V;
DeviceAllocation<ElementAccumulatorPV> block_LSE; DeviceAllocation<ElementOut> block_O;
DeviceAllocation<ElementOut> block_ref_O; DeviceAllocation<ElementAccumulatorPV> block_LSE;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE; DeviceAllocation<ElementOut> block_ref_O;
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
DeviceBuffer() = default;
DeviceBuffer(const DeviceBuffer&) = delete;
DeviceBuffer& operator=(const DeviceBuffer&) = delete;
size_t get_storage_size() const {
return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
+ block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
+ block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
+ device_cumulative_seqlen_kv.get_storage_size();
}
};
std::vector<std::unique_ptr<DeviceBuffer>> buffers;
std::vector<int> cumulative_seqlen_q; std::vector<int> cumulative_seqlen_q;
std::vector<int> cumulative_seqlen_kv; std::vector<int> cumulative_seqlen_kv;
DeviceAllocation<int> device_cumulative_seqlen_q;
DeviceAllocation<int> device_cumulative_seqlen_kv;
// //
// Methods // Methods
// //
bool verify(const ProblemShapeType& problem_shape) { bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
select<0,2,3>(problem_shape), select<0,2,3>(problem_shape),
stride_Q); stride_Q);
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
select<1,2,3>(problem_shape), select<1,2,3>(problem_shape),
stride_K); stride_K);
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
select<1,2,3>(problem_shape), select<1,2,3>(problem_shape),
stride_V); stride_V);
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()), Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
select<0,2,3>(problem_shape), select<0,2,3>(problem_shape),
stride_O); stride_O);
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()), Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
select<0,3>(problem_shape), select<0,3>(problem_shape),
stride_LSE); stride_LSE);
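Bundling the per-problem allocations into a `DeviceBuffer` and keeping `tensor_ring_buffers` of them lets the benchmark rotate through independent tensor sets, so back-to-back iterations do not re-read data that is still resident in L2. A CUTLASS-free sketch of that rotation (all names illustrative):

```cpp
#include <memory>
#include <vector>

struct DeviceBuffer {
  std::vector<float> q, k, v, o;  // stand-ins for the DeviceAllocation<> members above
};

int main() {
  int tensor_ring_buffers = 3;  // mirrors --tensor_ring_buffers
  std::vector<std::unique_ptr<DeviceBuffer>> buffers;
  for (int i = 0; i < tensor_ring_buffers; ++i) {
    buffers.push_back(std::make_unique<DeviceBuffer>());
  }

  int buffer_index = 0;
  for (int iter = 0; iter < 10; ++iter) {
    // ... run one timed iteration against *buffers[buffer_index] ...
    buffer_index = (buffer_index + 1) % static_cast<int>(buffers.size());  // rotate to cold data
  }
  return 0;
}
```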
@ -431,7 +455,7 @@ struct FwdRunner {
// Check if output from CUTLASS kernel and reference kernel are equal or not // Check if output from CUTLASS kernel and reference kernel are equal or not
double max_diff = 0; double max_diff = 0;
double mean_diff = 0; double mean_diff = 0;
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff); reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
if (! passed_O) { if (! passed_O) {
@ -439,14 +463,13 @@ struct FwdRunner {
<< " mean " << mean_diff << std::endl; << " mean " << mean_diff << std::endl;
} }
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff); reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);
bool passed_LSE = true; // future work bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if ( ! passed_LSE) {
// if ( ! passed_LSE) { std::cerr << "failed LSE: max diff " << max_diff
// std::cerr << "failed LSE: max diff " << max_diff << " mean " << mean_diff << std::endl;
// << " mean " << mean_diff << std::endl; }
// }
return passed_O && passed_LSE; return passed_O && passed_LSE;
} }
@ -559,50 +582,70 @@ struct FwdRunner {
get<1,1>(stride_LSE) = 0; get<1,1>(stride_LSE) = 0;
} }
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); auto buffer_init_fn = [&](auto& buffer) {
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
block_LSE.reset(size(shape_LSE)); buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
block_ref_O.reset(size(shape_QO)); buffer.block_LSE.reset(size(shape_LSE));
block_ref_LSE.reset(size(shape_LSE));
initialize_block(block_Q, seed + 2023, options.init_style_q); initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
initialize_block(block_K, seed + 2022, options.init_style_k); initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
initialize_block(block_V, seed + 2021, options.init_style_v); initialize_block(buffer.block_V, seed + 2021, options.init_style_v);
if ( ! cumulative_seqlen_q.empty()) { if ( ! cumulative_seqlen_q.empty()) {
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
device_cumulative_seqlen_q.copy_from_host( buffer.device_cumulative_seqlen_q.copy_from_host(
cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
} }
if ( ! cumulative_seqlen_kv.empty()) { if ( ! cumulative_seqlen_kv.empty()) {
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
device_cumulative_seqlen_kv.copy_from_host( buffer.device_cumulative_seqlen_kv.copy_from_host(
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
}
};
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
int tensor_ring_buffers = options.tensor_ring_buffers;
for (int i = 1; i < tensor_ring_buffers; i++) {
buffers.push_back(std::make_unique<DeviceBuffer>());
buffer_init_fn(*buffers.back());
} }
if constexpr (kIsVarlen) { if constexpr (kIsVarlen) {
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get(); get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get(); get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
} }
return problem_shape; return problem_shape;
} }
auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
}
typename Operation::Arguments arguments{
problem_shape_,
{ buffers[buffer_index]->block_Q.get(), stride_Q,
buffers[buffer_index]->block_K.get(), stride_K,
buffers[buffer_index]->block_V.get(), stride_V },
{ buffers[buffer_index]->block_O.get(), stride_O,
buffers[buffer_index]->block_LSE.get(), stride_LSE },
hw_info
};
return arguments;
}
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_shape = initialize(options); ProblemShapeType problem_shape = initialize(options);
typename Operation::Arguments arguments{ int buffer_index = 0;
problem_shape, typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);
{ block_Q.get(), stride_Q,
block_K.get(), stride_K,
block_V.get(), stride_V },
{ block_O.get(), stride_O,
block_LSE.get(), stride_LSE },
hw_info
};
Operation op; Operation op;
@ -630,11 +673,21 @@ struct FwdRunner {
} }
// Run // Run
status = op.run(); for (int i = 0; i < options.warmup_iterations; i++) {
if (status != cutlass::Status::kSuccess) { status = op.run();
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " if (status != cutlass::Status::kSuccess) {
<< cudaGetErrorString(cudaGetLastError()) << std::endl; std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
return example_result; << cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result;
}
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
} }
cudaError_t result = cudaDeviceSynchronize(); cudaError_t result = cudaDeviceSynchronize();
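Each warm-up iteration (and, below, each timed iteration) now advances `buffer_index`, rebuilds the arguments with `get_arguments(...)` for the next buffer in the ring, and calls `op.update(arguments, workspace.get())`, so only the tensor pointers in the kernel parameters change between iterations rather than re-running the full initialization.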
@ -672,6 +725,14 @@ struct FwdRunner {
<< cudaGetErrorString(cudaGetLastError()) << std::endl; << cudaGetErrorString(cudaGetLastError()) << std::endl;
return example_result; return example_result;
} }
buffer_index = (buffer_index + 1) % buffers.size();
arguments = get_arguments(problem_shape, hw_info, buffer_index);
status = op.update(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
<< std::endl;
return example_result;
}
} }
// //
@ -734,10 +795,10 @@ struct FwdRunner {
// Verify that the result is correct // Verify that the result is correct
bool passed = true; bool passed = true;
if (options.verify) { if (options.verify) {
passed = verify(problem_shape); passed = verify(problem_shape, *buffers[0]);
if (passed) example_result.verified = true; if (passed) example_result.verified = true;
} }
if (!passed) { if (!passed) {
std::cerr << "Reference check failed" << std::endl; std::cerr << "Reference check failed" << std::endl;
return example_result; return example_result;
@ -789,10 +850,14 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn
using HeadDim = _128; using HeadDim = _128;
// Persistent Tile Scheduler if (options.persistent) {
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{}); // Persistent Tile Scheduler
// Individual Tile Scheduler run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{}); }
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
} }
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
@ -818,10 +883,14 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _64; using HeadDim = _64;
// Persistent Tile Scheduler if (options.persistent) {
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{}); // Persistent Tile Scheduler
// Individual Tile Scheduler run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{}); }
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
} }
@ -845,10 +914,14 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf
using HeadDim = _32; using HeadDim = _32;
#ifdef FP8 #ifdef FP8
// Persistent Tile Scheduler if (options.persistent) {
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{}); // Persistent Tile Scheduler
// Individual Tile Scheduler run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{}); }
else {
// Individual Tile Scheduler
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
}
#endif #endif
} }
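In all three head-dimension variants the scheduler is now chosen at run time from `--persistent`, while the kernels themselves remain compile-time specializations; the flag simply selects which of the two instantiations to launch. A minimal illustration of that runtime-to-compile-time dispatch idiom (names illustrative, not the CUTLASS API):

```cpp
#include <iostream>

// Both specializations are instantiated; the runtime flag picks one to run.
template <bool kIsPersistent>
void run_kernel_variant() {
  std::cout << (kIsPersistent ? "persistent tile scheduler\n" : "individual tile scheduler\n");
}

void run_fwd(bool persistent) {
  if (persistent) {
    run_kernel_variant<true>();   // scheduler persists across tiles
  } else {
    run_kernel_variant<false>();  // roughly one CTA per output tile
  }
}

int main() {
  run_fwd(true);
  run_fwd(false);
  return 0;
}
```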

View File


@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
enum class InitStyle { enum class InitStyle {
kOne, kLinearStride128, kLinearStride1, kRandom, kNone kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
}; };
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
@ -98,6 +98,9 @@ struct Options {
if (s == "r") { if (s == "r") {
dst = InitStyle::kRandom; dst = InitStyle::kRandom;
} }
else if (s == "l") {
dst = InitStyle::kRandomLarge;
}
else if (s == "1") { else if (s == "1") {
dst = InitStyle::kOne; dst = InitStyle::kOne;
} }
@ -203,6 +206,11 @@ void initialize_block(
block.get(), block.size(), seed, (Element) -1, (Element) 1); block.get(), block.size(), seed, (Element) -1, (Element) 1);
break; break;
} }
case InitStyle::kRandomLarge: {
cutlass::reference::device::BlockFillRandomGaussian(
block.get(), block.size(), seed, (Element) -1, (Element) 100);
break;
}
case InitStyle::kLinearStride1: { case InitStyle::kLinearStride1: {
std::vector<Element> data(block.size()); std::vector<Element> data(block.size());
for (size_t i = 0; i < block.size() / 128; i ++) { for (size_t i = 0; i < block.size() / 128; i ++) {

View File


@ -144,4 +144,23 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC) target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v) target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
endforeach() endforeach()
# Add a target that builds all examples
add_custom_target(77_blackwell_fmha_all
DEPENDS
77_blackwell_fmha_fp8
77_blackwell_fmha_fp16
77_blackwell_fmha_gen_fp8
77_blackwell_fmha_gen_fp16
77_blackwell_mla_2sm_fp8
77_blackwell_mla_2sm_fp16
77_blackwell_mla_2sm_cpasync_fp8
77_blackwell_mla_2sm_cpasync_fp16
77_blackwell_mla_b2b_2sm_fp8
77_blackwell_mla_b2b_2sm_fp16
77_blackwell_fmha_bwd_fp8
77_blackwell_fmha_bwd_fp16
77_blackwell_fmha_bwd_sat_fp8
77_blackwell_fmha_bwd_sat_fp16
)
endif() endif()
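The aggregate `77_blackwell_fmha_all` target is a convenience so that every FMHA example binary listed above can be built in one step, e.g. with `cmake --build . --target 77_blackwell_fmha_all` from the configured build directory.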

View File


@ -157,8 +157,8 @@ struct CausalMask : NoMask {
TileShape const& tile_shape, TileShape const& tile_shape,
ProblemSize const& problem_size) { ProblemSize const& problem_size) {
int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} }
template<class BlkCoord, class TileShape, class ProblemSize> template<class BlkCoord, class TileShape, class ProblemSize>

View File


@ -42,7 +42,7 @@ template<
class ElementAcc, class ElementAcc,
class TileShape, // Q, D, _ class TileShape, // Q, D, _
class StrideO, // Q, D, B class StrideO, // Q, D, B
class StrideLSE // Q, B class StrideLSE_ // Q, B
> >
struct Sm100FmhaFwdEpilogueTmaWarpspecialized { struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
@ -54,6 +54,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); // using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO; using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
struct TensorStorage { struct TensorStorage {
@ -79,6 +80,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
struct Params { struct Params {
TMA_O tma_store_o; TMA_O tma_store_o;
ElementAcc* ptr_LSE;
StrideLSE dLSE;
}; };
template<class ProblemShape> template<class ProblemShape>
@ -110,7 +114,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
); );
return { return {
tma_store_o tma_store_o,
args.ptr_LSE,
args.dLSE
}; };
} }
@ -119,6 +125,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
} }
const Params& params;
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
template<class BlkCoord, class ProblemShape, class ParamsProblemShape> template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto CUTLASS_DEVICE auto
store( store(

View File


@ -531,7 +531,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row // Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
@ -613,7 +613,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert; NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
const int kReleasePipeCount = 10; // must be multiple of 2 const int kReleasePipeCount = 10; // must be multiple of 2
order_s.wait(); order_s.wait();
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
} }
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive(); order_s.arrive();
} }
@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum; row_sum = local_row_sum;
if (final_call) { if (final_call) {
@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64 // good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut); const int kCorrectionTileSize = 32 / sizeof(ElementOut);
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma; typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO); Tensor tOsO = mma.get_slice(0).partition_C(sO);
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{}))); Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{}))); Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{}))); Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
#ifndef ONLY_SOFTMAX #ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) { for (int j = 0; j < size(tTMrO); j += 2) {
@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// good values would be either 32 or 64 // good values would be either 32 or 64
const int kCorrectionTileSize = 16; const int kCorrectionTileSize = 16;
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma; typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{}))); Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{}))); Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O; tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float2 scale_f32x2 = make_float2(scale, scale); float2 scale_f32x2 = make_float2(scale, scale);
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
auto copy_in = [&](int i) { auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
} }
} }
template<class BlkCoord, class ProblemShape, class TensorStorageEpi> template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
CUTLASS_DEVICE auto CUTLASS_DEVICE auto
correction( correction(
BlkCoord const& blk_coord, BlkCoord const& blk_coord,
@ -951,7 +951,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) { PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
@ -961,7 +962,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
@ -1060,13 +1061,25 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP // F2FP
// store to smem // store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state); pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state; ++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state); pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state; ++pipeline_epi_producer_state;
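The value written to `gLSE` follows the standard log-sum-exp decomposition used by flash-attention-style kernels; with the running row maximum factored out of the accumulated exponential sum, the stored quantity is

$$
\mathrm{LSE}_i \;=\; \log \sum_j e^{s_{ij}} \;=\; \log\!\Big(\sum_j e^{s_{ij} - m_i}\Big) + m_i,
\qquad m_i = \max_j s_{ij},
$$

which matches `cutlass::fast_log(row_sum) + scale_softmax * row_max` above, under the assumption that the tracked maximum is kept in unscaled-score units and is therefore rescaled by `scale_softmax` at this point.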
@ -1083,6 +1096,16 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state); pipeline_o.consumer_release(pipeline_o_consumer_state);

View File


@ -118,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>; using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>; using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;
// compute S // compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc), make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{}) SmemLayoutDQ{}(_, _, _0{})
); );
return Params{ return Params{
args.problem_shape, args.problem_shape,
args.mainloop, args.mainloop,
@ -452,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});
auto tSTgK = cta_mma_kq.partition_A(gK); auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ); auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV); auto tDPTgV = cta_mma_vdo.partition_A(gV);
@ -477,7 +477,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));
// set up lse and sum_odo // set up lse and sum_odo
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
} }
// load Q // load Q
if (cute::elect_one_sync()) { if (cute::elect_one_sync()) {
cute::copy( cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -520,7 +520,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch), &mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q gmem_idx < Q
); );
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state; ++pipeline_load_compute_lse_producer_state;
@ -529,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);
// load V // load V
if (cute::elect_one_sync()) { if (cute::elect_one_sync()) {
cute::copy( cute::copy(
@ -540,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
} }
// load dO // load dO
if (cute::elect_one_sync()) { if (cute::elect_one_sync()) {
cute::copy( cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -573,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);
// load Q // load Q
if (cute::elect_one_sync()) { if (cute::elect_one_sync()) {
cute::copy( cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -584,7 +584,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_q_producer_state; ++pipeline_load_mma_q_producer_state;
pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);
// load LSE // load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
@ -593,15 +593,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch), &mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q gmem_idx < Q
); );
pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state; ++pipeline_load_compute_lse_producer_state;
pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);
// load dO // load dO
if (cute::elect_one_sync()) { if (cute::elect_one_sync()) {
cute::copy( cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -612,7 +612,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_do_producer_state; ++pipeline_load_mma_do_producer_state;
pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);
// load sum_OdO // load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
@ -621,7 +621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mSumOdO(gmem_idx, blk_coord_batch), &mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q gmem_idx < Q
); );
pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state; ++pipeline_load_compute_sum_odo_producer_state;
@ -639,23 +639,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int iter_count, int iter_count,
MainloopArguments const& mainloop_args, MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors, TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q, PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do, PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s, PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp, PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq, PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p, PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds, PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {
auto [Q, K, D, HB] = problem_shape; auto [Q, K, D, HB] = problem_shape;
auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
@ -685,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP; tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);
TiledMmaKQ tiled_mma_kq; TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo; TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk; TiledMmaDSK tiled_mma_dsk;
@ -923,6 +923,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord, TensorC const& coord,
TensorShape const& tensor_shape) { TensorShape const& tensor_shape) {
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy( auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{}, Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})), make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
@ -930,21 +932,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
); );
auto thr_copy = copy_op.get_slice(_0{}); auto thr_copy = copy_op.get_slice(_0{});
auto tCg = thr_copy.partition_D(gmem); Tensor tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs)); Tensor tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord); Tensor tPc = thr_copy.partition_D(preds);
constexpr int R = decltype(tCr.layout())::rank; copy_if(copy_op, tPc, tCr, tCg);
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
} }
@ -1073,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {
auto [Q, K, D, HB] = problem_shape; auto [Q, K, D, HB] = problem_shape;
// in tmem, S & P overlap // in tmem, S & P overlap
@ -1114,7 +1106,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST)); Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST)); Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));
Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p); Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT)); Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
@ -1152,20 +1144,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
fn(cute::false_type{}); fn(cute::false_type{});
} }
}; };
dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) { warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {
// compute P = softmax(S, LSE) // compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST); cute::copy(tiled_t2r, tTR_tST, tTR_rST);
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) { if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) { Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i); auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape); }, problem_shape);
} }
ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E); ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e; float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
@ -1184,16 +1176,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tTR_rST(i) = ::exp2f(out.x); tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y); tTR_rST(i+1) = ::exp2f(out.y);
} }
auto tRT_rST = quantize(tTR_rST); auto tRT_rST = quantize(tTR_rST);
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));
cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier( cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp, kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait(); ).arrive_and_wait();
cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
}); });
@ -1293,9 +1285,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store, PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {
using X = Underscore; using X = Underscore;
auto [Q, K, D, HB] = problem_shape; auto [Q, K, D, HB] = problem_shape;
auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
@ -1307,7 +1299,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDQtDQ.data() = TmemAllocation::kDQ; tDQtDQ.data() = TmemAllocation::kDQ;
Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch); (_, _, _, _0{}, blk_coord_batch);
Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1376,7 +1368,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1; iter_index += 1;
} }
} }
CUTLASS_DEVICE void operator()(Params const& params, char* smem) { CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync(); int warp_idx = cutlass::canonical_warp_idx_sync();
@ -1561,7 +1553,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;
auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>(); auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>(); auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>(); auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
@ -1587,7 +1579,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
if (role == WarpRole::Load) { if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>(); warpgroup_reg_set<RegisterAllocation::kLoad>();
load( load(
blk_coord, blk_coord,
problem_shape, problem_shape,
@ -1596,7 +1588,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
params.mainloop, params.mainloop,
params.mainloop_params, params.mainloop_params,
shared_storage.tensors, shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state, pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state, pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
@ -1608,7 +1600,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp(); __syncwarp();
mma( mma(
blk_coord, blk_coord,
problem_shape, problem_shape,
@ -1616,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count, iter_count,
params.mainloop, params.mainloop,
shared_storage.tensors, shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
@ -1629,7 +1621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
} }
else if (role == WarpRole::Compute) { else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>(); warpgroup_reg_set<RegisterAllocation::kCompute>();
compute( compute(
blk_coord, blk_coord,
problem_shape, problem_shape,
@ -1660,7 +1652,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
} }
else if (role == WarpRole::Reduce) { else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>(); warpgroup_reg_set<RegisterAllocation::kReduce>();
reduce( reduce(
blk_coord, blk_coord,
problem_shape, problem_shape,
@ -1677,9 +1669,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
} }
else { else {
warpgroup_reg_set<RegisterAllocation::kEmpty>(); warpgroup_reg_set<RegisterAllocation::kEmpty>();
/* no-op */ /* no-op */
} }
} }


@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>(); typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
CollectiveMainloop mainloop; CollectiveMainloop mainloop;
CollectiveEpilogue epilogue; CollectiveEpilogue epilogue{params.epilogue};
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>(); warpgroup_reg_set<NumRegsSoftmax>();
@ -407,7 +407,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_s0_corr, pipeline_s0_corr_consumer_state, pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state, pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state, pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
); );


@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0; ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) { for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max); sum_lse = sum_lse + expf(local_lse[i] - lse_max);
} }
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max; ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) { if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse; gLSE(0) = global_lse;
} }


@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale); mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
} }
if (threadIdx.x == 0) { if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
} }


@ -75,6 +75,8 @@ struct DeviceAllocation {
size_t size() const { return size_; } size_t size() const { return size_; }
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
void copy_from_host(const T* ptr, size_t sz) { void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess); assert(ret == cudaSuccess);

View File

@ -280,7 +280,7 @@ auto make_iterator(T* ptr) {
/// Testbed utility types /// Testbed utility types
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing // Command line options parsing
struct Options { struct Options {

View File

@ -133,7 +133,7 @@ using TP = _8;
static constexpr int TP_ = TP{}; static constexpr int TP_ = TP{};
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \ #if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
// Distributed GEMM tiling/sharding schedule // Distributed GEMM tiling/sharding schedule
// Choices: // Choices:
@ -254,7 +254,8 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_]; HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_]; HostTensorD tensor_D_arr[TP_];
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) #endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types /// Testbed utility types
@ -347,7 +348,7 @@ struct Result {
}; };
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \ #if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation /// GEMM setup and evaluation
@ -805,17 +806,16 @@ int run(Options &options) {
return 0; return 0;
} }
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) #endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) { int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example // CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run Blackwell kernels.
// and must have compute capability at least 90. if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
// Some necessary cuda graph APIs were only introduced in CUDA 12.4. std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op. // Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0; return 0;
} }
@ -861,11 +861,11 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels // Evaluate CUTLASS kernels
// //
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
run(options); run(options);
#else #else
std::cerr std::cerr
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.4 or later." << std::endl; << "This example must be compiled with `sm100a` and CUDA Toolkit 12.8 or later." << std::endl;
return 0; return 0;
#endif #endif


@ -14,8 +14,8 @@ cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
### Minimum software ### Minimum software
Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required. Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary This example specifically requires CUDA Toolkit 12.8 or newer, since that is the first version
CUDA graph APIs. supporting the Blackwell architecture.
### Hardware / driver settings ### Hardware / driver settings

File diff suppressed because it is too large


@ -0,0 +1,50 @@
# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
88_hopper_fmha
88_hopper_fmha.cu
)
if(NOT WIN32 AND NOT CUTLASS_CLANG_HOST_COMPILE)
set_property(
SOURCE 88_hopper_fmha.cu
PROPERTY COMPILE_FLAGS "--use_fast_math"
)
cutlass_example_add_executable(
88_hopper_fmha_fp8
88_hopper_fmha.cu
)
target_compile_definitions(
88_hopper_fmha_fp8
PRIVATE FP8)
endif()

View File

@ -0,0 +1,77 @@
# CUTLASS Hopper FMHA Example
This sample showcases how to implement fused multi-head attention (FMHA) using
CUTLASS for the NVIDIA Hopper architecture. At its heart, the forward pass of
FMHA is a GEMM-online-softmax-GEMM fusion, whereas the backward pass has a slightly
more complex structure (essentially a GEMM-softmax-2xGEMM-2xGEMM fusion).
For more information please refer to the [Flash Attention 3 paper](https://arxiv.org/abs/2407.08608).
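
For orientation, the dataflow can be written schematically as follows (our notation; softmax scaling, masking, and the online-softmax bookkeeping are omitted):

```latex
% forward pass
S = Q K^T, \qquad P = \mathrm{softmax}(S), \qquad O = P V
% backward pass
dV = P^T \, dO, \qquad dP = dO \, V^T, \qquad
dS = P \circ \bigl(dP - \mathrm{rowsum}(dO \circ O)\bigr), \qquad
dQ = dS \, K, \qquad dK = dS^T \, Q
```

The backward pass thus recomputes S, applies the softmax correction, and issues the GEMM pairs (dP, dV) and (dQ, dK); the rowsum(dO ∘ O) term corresponds to the sum_OdO quantity used by the backward kernels.
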
The forward pass kernel supports head dims 32, 64, 128, and 256 for fp16 and bf16 input data types,
and head dims 128 and 256 for fp8.
All kernels use the Tensor Memory Accelerator for loads.
Kernels with head dims 128 and 256 have warp-specialized cooperative schedules.
Backward pass kernels (fp16 only) support head dims 32, 64, and 128, and all use
warp-specialized cooperative schedules.
## Customization
### Mask Fusion
Similar to the [Blackwell FMHA example](../77_blackwell_fmha/README.md), attention masks such as
causal masking can be fused into the kernel. To modify the code for such fusions,
`collective/fmha_fusion.hpp` provides the easiest customization point.
The `before_softmax` function is called with the accumulator of the first GEMM and the logical
positions of those elements. It is well-suited for applying masks or activations.
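
As an illustration, here is a minimal sketch of a custom fusion that applies a causal mask inside `before_softmax`. The name `MyCausalFusion` and the coordinate convention are ours; the signature mirrors the call site `Fusion{}.before_softmax(acc_S, tScS, problem_size)` used by the mainloops, and a complete fusion also provides the trip-count / contribution queries found in `collective/fmha_fusion.hpp`.

```cpp
// Sketch only: a user-defined fusion hooking into before_softmax.
// Assumes index_qk(i) yields the logical (q, k) coordinate of accumulator element i.
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

struct MyCausalFusion {
  template <class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE void
  before_softmax(AccQK& acc_qk, IndexQK const& index_qk, ProblemSize const& /*problem_size*/) {
    using namespace cute;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); ++i) {
      auto qk = index_qk(i);              // logical (q, k) position of this element
      if (get<1>(qk) > get<0>(qk)) {      // key index beyond the query index
        acc_qk(i) = -INFINITY;            // masked out before the softmax
      }
    }
  }
};
```
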
### MHA Variants
Using CuTe, it is easy to represent the various attention variants via the layout of the head dimension:
* regular multi-head attention: (numHeads:headStride),
* single-head attention: simply (1:0) everywhere,
* GQA: the normal layout in Q and (numHeads/numGroups,numGroups:headStride,0) in KV,
* MQA: the normal layout in Q and (numHeads:0) in KV.
As such, beyond general stride handling, no additional work is needed to support these,
and the example will just demonstrate regular multi-head attention.
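
As a minimal sketch (illustrative only; variable names are ours, and which mode carries the zero stride depends on how head indices are enumerated), these head-dimension layouts can be written with CuTe as:

```cpp
// Illustrative only: the head-mode layouts described above, written with CuTe.
// Variable names and the printing harness are ours, not part of the example.
#include <cstdio>
#include <cute/tensor.hpp>

inline void show_head_layouts() {
  using namespace cute;
  int num_heads = 16, num_groups = 4, head_stride = 128;

  auto mha_kv = make_layout(num_heads, head_stride);       // (numHeads:headStride)
  auto sha_kv = make_layout(_1{}, _0{});                   // (1:0) single-head
  // GQA written exactly as in the text: ((numHeads/numGroups,numGroups):(headStride,0));
  // which mode gets the zero stride depends on how head indices are enumerated.
  auto gqa_kv = make_layout(make_shape (num_heads / num_groups, num_groups),
                            make_stride(head_stride, _0{}));
  auto mqa_kv = make_layout(num_heads, _0{});              // (numHeads:0)

  print(mha_kv); printf("\n");
  print(sha_kv); printf("\n");
  print(gqa_kv); printf("\n");
  print(mqa_kv); printf("\n");
}
```
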
### FP8
The warp-specialized forward kernel supports FP8 computation with both FP32 and FP16
accumulation for the Q*K product. These can be enabled in the runner by defining `FP8`.
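
For reference, a hypothetical sketch of what the `FP8` define might switch at compile time (the element types shown are assumptions, not the runner's actual definitions; see the example source and the `88_hopper_fmha_fp8` CMake target for the real wiring):

```cpp
#include "cutlass/numeric_types.h"

// Hypothetical: the FP8 build selects 8-bit floating-point inputs for Q/K/V.
#if defined(FP8)
using Element = cutlass::float_e4m3_t;   // assumption: e4m3 inputs
#else
using Element = cutlass::half_t;         // fp16 inputs otherwise
#endif
```
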
## Performance
Forward pass kernels generally come close to the performance of FA3, whereas backward pass
kernels are more limited and are not expected to reach the same level of performance as FA3.
# Copyright
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


@ -0,0 +1,863 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::collective {
template<
typename Element_,
typename ElementAccumulator_,
typename TileShape_, // BlockQO, BlockKV, BlockHead
class Fusion,
class... Options
>
struct FmhaBwdMainloopTmaWarpSpecialized {
using Element = Element_;
using ElementAccumulator = ElementAccumulator_;
using TileShape = TileShape_;
static constexpr bool kIsPersistent = false;
static const int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = 2;
static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups;
static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */;
static const int kOuterLoads = 2;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<_1, _1, _1>;
static_assert(StagesQ::value >= 2);
static_assert(Stages::value >= 2 * NumMmaWarpGroups);
// 16B alignment lets us use TMA
static constexpr int Alignment = 16 / sizeof(Element);
using TileShapeNM = Shape< // (N,M,D)
decltype(tuple_element_t<1, TileShape>{} / Int<NumMmaWarpGroups>{}),
tuple_element_t<0, TileShape>,
tuple_element_t<2, TileShape>>;
using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M)
using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N)
using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeNM, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeND, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeND, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment, // from smem, might matter (?)
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
ElementAccumulator,
TileShapeMD, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using TiledMmaNM = typename CollectiveMmaNM::TiledMma;
using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma;
using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{}));
using TiledMmaND = TiledMmaND_RS;
using TiledMmaMD = typename CollectiveMmaMD::TiledMma;
using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB;
using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA;
using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA;
using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB;
//using SmemLayoutDQ = Layout<
// Shape<
// tuple_element_t<0, TileShapeMD>,
// Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>,
// _2
// >,
// Stride<
// _4,
// Stride<decltype(tuple_element_t<0, TileShapeMD>{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>,
// decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
// >>;
using SmemLayoutDQ_0 = Layout<
Shape<
tuple_element_t<0, TileShapeMD>,
tuple_element_t<1, TileShapeMD>,
_2
>,
Stride<
tuple_element_t<1, TileShapeMD>,
_1,
decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
>>;
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>());
using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{}));
using SmemLayoutDQ = SmemLayoutDQ_1;
using PipelineDQ = cutlass::PipelineAsync<2>;
using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int<NumMmaWarpGroups>{}));
using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom<Element>{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{}));
using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB;
using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB;
using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB;
using SmemLayoutLSE = Layout<Shape<tuple_element_t<1, TileShapeNM>, Int<StageCount>>>;
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
using TileShapePV = TileShapeND; // To work with the kernel level
using TiledMmaPV = TiledMmaND;
static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator);
static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
struct SharedStorage {
// One for each consumer WG
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutKp>> smem_kp;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;
// Loaded by producer, consumed by both WGs
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQp>> smem_qp;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutDOp>> smem_dop;
};
// Accumulated into by both consumers, potentially loaded, potentially written
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutDQ>> smem_dq;
union {
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_lse;
cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_sumOdO;
};
};
struct Arguments {
const Element* ptr_Q;
cute::tuple<int, int, int, _1> dQ;
const Element* ptr_K;
cute::tuple<int, int, int, _1> dK;
const Element* ptr_V;
cute::tuple<int, int, int, _1> dV;
const Element* ptr_dO;
cute::tuple<int, int, int, _1> dDO;
const ElementAccumulator* ptr_LSE;
cute::tuple<int, int, _1> dLSE;
const ElementAccumulator* ptr_sum_OdO;
cute::tuple<int, int, _1> dSumOdO;
ElementAccumulator* ptr_dQ;
cute::tuple<int, int, int, _1> dDQ;
};
using TMA_Q = typename CollectiveMmaNM::Params::TMA_B;
using TMA_K = typename CollectiveMmaNM::Params::TMA_A;
using TMA_V = typename CollectiveMmaNM::Params::TMA_A;
using TMA_DO = typename CollectiveMmaNM::Params::TMA_B;
using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{})));
using TMA_ODO = TMA_LSE;
using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{})));
using LoadQ = CollectiveLoadTma<
LoadKind::kBwdM,
MainloopPipeline,
Element,
SmemLayoutQ,
TMA_Q
>;
using LoadK = CollectiveLoadTma<
LoadKind::kBwdN,
MainloopPipelineQ,
Element,
SmemLayoutK,
TMA_K
>;
using LoadV = CollectiveLoadTma<
LoadKind::kBwdN,
MainloopPipelineQ,
Element,
SmemLayoutV,
TMA_V
>;
using LoadDO = CollectiveLoadTma<
LoadKind::kBwdM,
MainloopPipeline,
Element,
SmemLayoutDO,
TMA_DO
>;
using LoadLSE = CollectiveLoadTma<
LoadKind::kBwdScalar,
MainloopPipeline,
ElementAccumulator,
SmemLayoutLSE,
TMA_LSE
>;
using LoadODO = CollectiveLoadTma<
LoadKind::kBwdScalar,
MainloopPipeline,
ElementAccumulator,
SmemLayoutLSE,
TMA_ODO
>;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
TMA_DO tma_load_do;
TMA_LSE tma_load_lse;
TMA_ODO tma_load_odo;
TMA_DQ tma_red_dq;
float scale_softmax;
float scale_softmax_log2;
};
static_assert(size(TiledMmaNM{}) == size(TiledMmaND{}));
static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{}));
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
return true
&& (get<4>(problem_size) <= get<2>(TileShape{}))
&& ((get<4>(problem_size) % Alignment) == 0)
&& ((get<2>(problem_size) % Alignment) == 0)
;
}
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK)));
auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ)));
auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
typename CollectiveMmaNM::Arguments {
args.ptr_K, dK,
args.ptr_Q, dQ,
}, /*workspace=*/ nullptr);
auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV)));
auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO)));
auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
typename CollectiveMmaNM::Arguments {
args.ptr_V, dV,
args.ptr_dO, dDO,
}, /*workspace=*/ nullptr);
TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{}));
TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{}));
TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{}));
return Params{
params_nm_kq.tma_load_b,
params_nm_kq.tma_load_a,
params_nm_vdo.tma_load_a,
params_nm_vdo.tma_load_b,
tma_load_lse, tma_load_odo,
tma_red_dq,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size)))
};
}
template<class BlkCoord, class ProblemSize>
CUTLASS_DEVICE
auto
get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) {
return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor());
}
template<bool kLoadOuter, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_kv_maybe_q(
int block_rank_in_cluster,
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
// Load pattern:
// K0 V0 K1 V1
// Q0 DO0 Q1 DO1 Q2 DO2 ...
// K0 Q0 V0 K1 DO0 V1 ...
int lane_predicate = cute::elect_one_sync();
int outer_tile_count = NumMmaWarpGroups;
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count);
uint16_t mcast_mask_b = 0;
LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q};
auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do};
auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse};
auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO};
auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
outer_tile_count *= 2; // K & V
inner_tile_count *= 4; // Q & dO & LSE & sumOdO
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 4;
++inner_tile_iter;
}
if constexpr (kLoadOuter) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
if constexpr (! kLoadOuter) {
if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
}
if constexpr (kLoadOuter) {
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
if constexpr (kLoadOuter) {
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
if constexpr (kLoadOuter) {
while (outer_tile_count > 0) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
}
CUTLASS_PRAGMA_NO_UNROLL
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 4;
++inner_tile_iter;
}
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
}
}
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_maybe_q(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
// Load pattern:
// K0 V0 K1 V1
// Q0 DO0 Q1 DO1 Q2 DO2 ...
// K0 Q0 V0 K1 DO0 V1 ...
int lane_predicate = cute::elect_one_sync();
int outer_tile_count = NumMmaWarpGroups;
auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
outer_tile_count *= 2; // K & V
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
while (outer_tile_count > 0) {
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
}
}
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
CUTLASS_DEVICE void
reduce(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer,
SharedStorage& storage)
{
int lane_predicate = cute::elect_one_sync();
Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size));
Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{});
Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord));
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
auto block_tma = params.tma_red_dq.get_slice(_0{});
Tensor tDQsDQ = block_tma.partition_S(sDQ);
Tensor tDQgDQ = block_tma.partition_D(gDQ);
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
int g_index = 0;
auto smem_pipe_release_reducer = smem_pipe_read_reducer;
bool first = true;
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 1;
++g_index;
}
if (inner_tile_count == 0) break;
pipeline_reducer.consumer_wait(smem_pipe_read_reducer);
if (lane_predicate == 1) {
tma_store_wait<1>();
}
if (! first) {
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
++smem_pipe_release_reducer;
} else {
first = false;
}
if (lane_predicate == 1) {
copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index));
tma_store_arrive();
}
++smem_pipe_read_reducer;
--inner_tile_count;
++g_index;
}
if (lane_predicate) {
tma_store_wait<0>();
}
pipeline_reducer.consumer_release(smem_pipe_release_reducer);
++smem_pipe_release_reducer;
}
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
CUTLASS_DEVICE auto
compute(
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner,
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
SharedStorage& storage,
MathWgOrderBarrier& math_wg_order_barrier)
{
TiledMmaND tiled_mma_nd;
Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
clear(acc_DV);
Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
clear(acc_DK);
int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup;
PipelineState smem_pipe_release_inner = smem_pipe_read_inner;
pipeline_outer.consumer_wait(smem_pipe_read_outer);
PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer;
++smem_pipe_read_outer;
pipeline_outer.consumer_wait(smem_pipe_read_outer);
PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer;
int inner_tile_count = get_inner_tile_count(wg_coord, problem_size);
TiledMmaNM tiled_mma_nm;
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx);
Tensor tSsK = thr_mma_nm.partition_A(sK);
Tensor tSsQ = thr_mma_nm.partition_B(sQ);
Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK);
Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{});
Tensor tDPsV = thr_mma_nm.partition_A(sV);
Tensor tDPsDO = thr_mma_nm.partition_B(sDO);
Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV);
Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO);
auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx);
Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{});
Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp);
Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO);
Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{});
Tensor tDK_sQ = thr_mma_nd.partition_B(sQp);
Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ);
int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0);
TiledMmaMD tiled_mma_md;
auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx);
Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{});
Tensor tDQsDS = thr_mma_md.partition_A(sDS);
Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS);
Tensor tDQrDS = tDQrDS_full(_,_,_,_);
Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{});
Tensor tDQsK = thr_mma_md.partition_B(sKp);
Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK);
Tensor sLSE = make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
Tensor tSsLSE = thr_mma_nm.partition_C(sLSE);
Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
Tensor tDPsODO = thr_mma_nm.partition_C(sODO);
Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{}));
Tensor tScS = thr_mma_nm.partition_C(cS);
int n_block = get<1>(wg_coord);
tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{});
// Transpose
Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS))));
Tensor sDSp = sDSp_full(_,_,_);
Tensor tDPsDS = thr_mma_nm.partition_C(sDSp);
auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx);
Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp);
Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp);
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
auto tDQsDQ_full = thr_mma_md.partition_C(sDQ);
auto smem_pipe_read_k_other = smem_pipe_read_k;
smem_pipe_read_k_other.advance(2);
int k_index = 0;
while (inner_tile_count > 0) {
while (inner_tile_count > 0) {
if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
break;
}
inner_tile_count -= 1;
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
k_index += 1;
}
if (inner_tile_count == 0) break;
pipeline_inner.consumer_wait(smem_pipe_read_inner);
PipelineState smem_pipe_read_q = smem_pipe_read_inner;
++smem_pipe_read_inner;
PipelineState smem_pipe_read_do = smem_pipe_read_inner;
++smem_pipe_read_inner;
// GEMM KQ -> S
Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
warpgroup_fence_operand(acc_S);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S);
warpgroup_commit_batch();
pipeline_inner.consumer_wait(smem_pipe_read_do);
// GEMM VdO -> dP
Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
warpgroup_fence_operand(acc_DP);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP);
warpgroup_commit_batch();
Tensor reg_LSE = make_fragment_like<ElementAccumulator>(acc_S);
for (int i = 0; i < size(reg_LSE); i++) {
reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i);
}
Tensor reg_ODO = make_fragment_like<ElementAccumulator>(acc_S);
if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) {
for (int i = 0; i < size(reg_ODO); i++) {
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
}
}
warpgroup_wait<1>();
warpgroup_fence_operand(acc_S);
math_wg_order_barrier.wait();
// Compute S -> P
Fusion{}.before_softmax(acc_S, tScS, problem_size);
auto acc_P = make_fragment_like<ElementAccumulator>(acc_S);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_P); i++) {
acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i));
}
math_wg_order_barrier.arrive();
if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) {
for (int i = 0; i < size(reg_ODO); i++) {
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
}
}
warpgroup_wait<0>();
warpgroup_fence_operand(acc_DP);
// Compute dP P -> dS
auto acc_DS = make_fragment_like<Element>(acc_DP);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DS); i++) {
// We could move the scale out and into the respective epilogues (or a final scaling step)
acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i));
}
// GEMM PdO -> dV
auto op_P = make_acc_into_op<Element>(acc_P, typename TiledMmaND::LayoutA_TV{});
warpgroup_fence_operand(acc_DV);
warpgroup_fence_operand(op_P);
warpgroup_arrive();
cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV);
warpgroup_commit_batch();
// Store dS to smem dS'
if (wg_idx == 0) math_wg_order_barrier.wait();
auto recast_bits = [](auto sz, auto t) {
return recast<uint_bit_t<decltype(sz)::value>>(t);
};
auto tDPsDS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, tDPsDS);
auto acc_DS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, acc_DS);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DS_v); i++) {
tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i);
}
cutlass::arch::fence_view_async_shared();
if (wg_idx == 0) math_wg_order_barrier.arrive();
// GEMM dS Q -> dK
if (wg_idx == 1) {
math_wg_order_barrier.wait();
// GEMM dS' K -> dQ
Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{}));
warpgroup_fence_operand(acc_DQ);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ);
cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ);
warpgroup_commit_batch();
warpgroup_fence_operand(acc_DK);
warpgroup_arrive();
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
warpgroup_commit_batch();
warpgroup_wait<1>();
warpgroup_fence_operand(acc_DK);
warpgroup_wait<1>();
warpgroup_fence_operand(acc_DQ);
math_wg_order_barrier.arrive();
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index());
// Store dQ to smem dQ'
// Invoke TMA reduce on dQ'
using Vec = uint_bit_t<sizeof_bits_v<ElementAccumulator> * 2>;
auto tDQsDQ_v = recast<Vec>(tDQsDQ);
auto acc_DQ_v = recast<Vec>(acc_DQ);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_DQ_v); i++) {
tDQsDQ_v(i) = acc_DQ_v(i);
}
cutlass::arch::fence_view_async_shared();
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
++smem_pipe_write_reducer;
} else {
warpgroup_fence_operand(acc_DK);
warpgroup_arrive();
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
warpgroup_commit_batch();
warpgroup_wait<1>();
warpgroup_fence_operand(acc_DK);
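// This warpgroup contributes no dQ tile on this path; acquire and immediately commit an
// empty reducer stage so both warpgroups keep the reducer pipeline phases aligned.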
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
++smem_pipe_write_reducer;
}
--inner_tile_count;
pipeline_inner.consumer_release(smem_pipe_release_inner);
++smem_pipe_release_inner;
pipeline_inner.consumer_release(smem_pipe_release_inner);
++smem_pipe_release_inner;
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
k_index += 1;
}
pipeline_outer.consumer_release(smem_pipe_read_k);
pipeline_outer.consumer_release(smem_pipe_read_outer);
pipeline_reducer.producer_tail(smem_pipe_write_reducer);
++smem_pipe_read_outer;
warpgroup_wait<0>();
warpgroup_fence_operand(acc_DK);
warpgroup_fence_operand(acc_DV);
return make_tuple(acc_DK, acc_DV);
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,140 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
enum class LoadKind {
kQ, kK, kV,
kBwdN, kBwdM, kBwdScalar
};
template<
LoadKind kKind,
class Pipeline,
class Element,
class SmemLayout,
class TMA
>
struct CollectiveLoadTma {
using Params = TMA;
using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayout>>;
using PipelineState = typename cutlass::PipelineState<Pipeline::Stages>;
Params const& params;
Pipeline& pipeline;
SharedStorage& storage;
CUTLASS_DEVICE
CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage)
: params(params), pipeline(pipeline), storage(storage) {}
template<class ProblemSize, class TileShape, class BlockCoord>
CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape,
BlockCoord const& blk_coord, int loop_count
) {
using X = Underscore;
if constexpr (kKind == LoadKind::kK) {
Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord));
return gK;
} else if constexpr (kKind == LoadKind::kQ) {
Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord));
return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout());
} else if constexpr (kKind == LoadKind::kV) {
Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size)));
Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord));
return gV;
} else if constexpr (kKind == LoadKind::kBwdN) {
Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout());
} else if constexpr (kKind == LoadKind::kBwdM) {
Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
return g;
} else if constexpr (kKind == LoadKind::kBwdScalar) {
Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size));
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, X>{});
Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord));
return g;
}
}
template<class ClusterRank, class ProblemSize, class TileShape, class BlockCoord>
CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster,
ProblemSize const& problem_size, TileShape const& tile_shape,
BlockCoord const& block_coord, int loop_count
) {
Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count);
Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{});
auto block_tma = params.get_slice(block_rank_in_cluster);
Tensor ts = block_tma.partition_D(s);
Tensor tg = block_tma.partition_S(g);
return make_tuple(tg, ts);
}
template<bool kAdvanceIterator=true, bool kAdvancePipe=true, bool kAcquireBarrier=true, class TileIterator, class State>
CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state,
PipelineState& smem_pipe_write,
int lane_predicate, int& tile_count, uint16_t mcast_mask = 0
) {
if ((lane_predicate == 1) && (tile_count > 0)) {
if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write);
using BarrierType = typename Pipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
if constexpr (kKind == LoadKind::kBwdScalar) {
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index()));
} else {
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index()));
}
if constexpr (kAdvancePipe) ++smem_pipe_write;
if constexpr (kAdvanceIterator) ++tile_iter;
}
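// Note: tile_count is decremented by every calling lane, not just the elected one, so the
// whole warp observes the same remaining-tile count.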
--tile_count;
}
};
} // namespace cutlass::fmha::collective
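// Usage sketch (orientation only; it mirrors load_kv_maybe_q in the warp-specialized
// mainloop later in this diff -- params, pipeline, storage and the tile shapes are the
// members defined by that mainloop):
//
//   LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
//   LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
//   auto state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, trip_count);
//   auto state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, trip_count);
//   auto k_iter  = cute::make_coord_iterator(trip_count);
//   int remaining = 2 * trip_count;   // K and V each consume one pipeline stage per tile
//   int lane_predicate = cute::elect_one_sync();
//   while (remaining > 0) {
//     load_k.template step<false>(k_iter, state_k, smem_pipe_write, lane_predicate, remaining); // keep iterator
//     load_v.template step<true >(k_iter, state_v, smem_pipe_write, lane_predicate, remaining); // advance iterator
//   }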

View File

@ -0,0 +1,305 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "../collective/fmha_common.hpp"
namespace cutlass::fmha::collective {
template<
class ElementAccumulator,
class Fusion,
class Params
>
struct CollectiveSoftmax {
Params const& params;
CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {}
using SumType = float;
using MaxType = ElementAccumulator;
template<class AccPV, class TiledMmaPV>
CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) {
Tensor s_max = make_fragment_like<MaxType>(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout())));
Tensor a_sum = make_fragment_like<SumType>(s_max);
return make_tuple(s_max, a_sum);
}
CUTLASS_DEVICE float overload_exp2(float f) {
return ::exp2f(f);
}
CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) {
auto a = f.raw();
decltype(a) d;
asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a));
return cutlass::half_t::bitcast(d);
}
CUTLASS_DEVICE float overload_max(float a, float b) {
return ::max(a, b);
}
CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) {
return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())};
}
CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) {
return f.to_half();
}
CUTLASS_DEVICE float overload_to_native(float f) {
return f;
}
template<class AccQK, class TiledMmaQK, class CountQK, class State, class ProblemShape>
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) {
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
// Linear reduction is faster for the first iteration
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = acc_qk_mn(i, 0);
}
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < size<1>(acc_qk_mn); j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
}
}
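// Butterfly (xor-shuffle) reduction of the running row maximum across the threads that
// share an accumulator row; reduction_target_qk encodes which thread modes span N.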
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
}
}
});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
MaxType scale_max = scale * local_max;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
}
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
}
}
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
if constexpr (kUseFusion) {
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
}
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
Tensor s_max_prev = make_fragment_like(s_max);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max_prev(i) = s_max(i);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
// Linear reduction is faster here, as well
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
}
}
// reduce max
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride<r>(reduction_target_qk) * j));
}
}
});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i);
float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2);
a_sum(i) *= scale;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
acc_pv_mn(i, j) *= scale;
}
}
}
template<class AccQK_MN, class State>
CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) {
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<0>(acc_qk_mn); j++) {
float local_max = s_max(j) == -INFINITY ? 0.f : s_max(j);
float scale_max = params.scale_softmax_log2 * local_max;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<1>(acc_qk_mn); k++) {
acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max);
a_sum(j) += acc_qk_mn(j, k);
}
}
}
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
if constexpr (kUseFusion) {
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
}
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
Tensor s_max_prev = make_fragment_like(s_max);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
s_max_prev(i) = s_max(i);
// Linear reduction is faster here, as well
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
}
// reduce max
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
}
});
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
MaxType scale_max = scale * local_max;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
}
MaxType s_max_cur = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale);
a_sum(i) *= scale_pv;
using ElementPV = typename AccPV::value_type;
ElementPV scale_pv_ele = static_cast<ElementPV>(scale_pv);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
acc_pv_mn(i, j) *= scale_pv_ele;
}
a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
}
}
template<class State, class AccPV, class TiledMmaPV>
CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) {
auto& s_max = get<0>(state);
auto& a_sum = get<1>(state);
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
auto reduction_target = reduction_target_n(tiled_mma_pv);
constexpr int red_rank = decltype(rank(reduction_target))::value;
for_each(make_seq<red_rank>{}, [&](auto r) {
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < shape<r>(reduction_target); j *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride<r>(reduction_target) * j);
}
}
});
Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
Tensor lse = make_fragment_like(a_sum);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(acc_mn); i++) {
float sum = a_sum(i);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum);
lse(i) = (sum == 0.f || sum != sum) ? INFINITY : s_max(i) * params.scale_softmax + __logf(sum);
float scale = params.rp_dropout * inv_sum;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(acc_mn); j++) {
acc_mn(i, j) *= scale;
}
}
return lse;
}
};
} // namespace cutlass::fmha::collective
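// Usage sketch (orientation only; this mirrors the call sequence of the TMA mainloops
// later in this diff -- acc_qk / acc_pv are the QK and PV accumulators, tPcP the
// coordinate tensor passed to the fusion, Params the mainloop's parameter struct):
//
//   CollectiveSoftmax<ElementAccumulator, Fusion, Params> softmax{params};
//   auto state = softmax.init(acc_pv, tiled_mma_pv);                 // running max and running sum
//   softmax.step(acc_qk, tiled_mma_qk, tPcP, state, problem_size);   // first KV tile
//   softmax.step(acc_qk, tiled_mma_qk, tPcP, state,
//                acc_pv, tiled_mma_pv, problem_size);                // later tiles: rescales acc_pv
//   Tensor lse = softmax.tail(state, acc_pv, tiled_mma_pv);          // normalizes acc_pv, returns LSE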

View File

@ -0,0 +1,526 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
using cutlass::fmha::kernel::Tag;
using cutlass::fmha::kernel::find_option_t;
template<
typename Element_,
typename ElementAccumulator_,
typename TileShape_, // BlockQO, BlockKV, BlockHead
class Fusion,
class... Options
>
struct FmhaMainloopTma {
using Element = Element_;
using ElementAccumulator = ElementAccumulator_;
using TileShape = TileShape_;
// Options
using kClusterM = find_option_t<Tag::kClusterM, Int<1>, Options...>;
static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<4>, Options...>::value;
static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<1>, Options...>::value;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<kClusterM, _1, _1>;
// 16B alignment lets us use TMA
static constexpr int Alignment = 16 / sizeof(Element);
using TileShapeQK = TileShape;
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
using LayoutQKV = cute::tuple<int, _1, cute::tuple<int, int>>;
using LayoutQ = LayoutQKV;
using LayoutK = LayoutQKV;
using LayoutV = LayoutQKV;
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, LayoutQ, Alignment,
Element, LayoutK, Alignment,
ElementAccumulator,
TileShapeQK, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since we do not load from smem at all
Element, LayoutK, Alignment,
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
ElementAccumulator,
TileShapePV, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
using TileShapeOut = TileShapePV;
using TiledMmaOut = TiledMmaPV;
using ElementOut = ElementAccumulator;
struct SharedStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
};
struct Arguments {
const Element* ptr_Q;
LayoutQ dQ;
const Element* ptr_K;
LayoutK dK;
const Element* ptr_V;
LayoutV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
float scale_softmax;
float scale_softmax_log2;
float rp_dropout;
};
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kQ,
MainloopPipelineQ,
Element,
SmemLayoutQ,
TMA_Q
>;
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kK,
MainloopPipeline,
Element,
SmemLayoutK,
TMA_K
>;
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kV,
MainloopPipeline,
Element,
SmemLayoutV,
TMA_V
>;
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
static const int MaxThreadsPerBlock = size(typename CollectiveMmaQK::TiledMma{});
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
return true
&& (get<4>(problem_size) <= get<2>(TileShape{}))
&& ((get<4>(problem_size) % Alignment) == 0)
&& ((get<2>(problem_size) % Alignment) == 0)
;
}
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
typename CollectiveMmaQK::Arguments {
args.ptr_Q, args.dQ,
args.ptr_K, args.dK,
}, /*workspace=*/ nullptr);
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
typename CollectiveMmaPV::Arguments {
args.ptr_K, args.dK, // never used, dummy
args.ptr_V, select<1,0,2>(args.dV),
}, /*workspace=*/ nullptr);
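// scale_softmax is 1/sqrt(head_dim); scale_softmax_log2 additionally folds in log2(e)
// so the softmax exponentials can be evaluated with exp2f on the device.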
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
1.0f
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
template<class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
compute(
int block_rank_in_cluster,
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q,
SharedStorage& storage)
{
int warp_idx = cutlass::canonical_warp_idx_sync();
int thread_idx = threadIdx.x;
PipelineState smem_pipe_release = smem_pipe_read;
[[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1);
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
// Set predicate for the lowest lane_id in the warp
int lane_predicate = cute::elect_one_sync();
// Issue TmaLoads (Prologue fetches)
if (warp_idx == 0) {
auto q_tile_iter = cute::make_coord_iterator(1);
int q_tile_count = 1;
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
// Loop over K elems
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
int k_tile_count_tma = 2 * fusion_tile_count;
uint16_t mcast_mask_b = 0;
if (warp_idx == 0 && lane_predicate == 1) {
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
}
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < StageCount; i++) {
if (i % 2 == 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
} else {
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); // V loader uses its own state, as in the mainloop below
}
}
}
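// The shared K/V pipeline is now primed with alternating K and V tiles; every logical KV
// step consumes two stages, which is why k_tile_count_tma counts 2 * fusion_tile_count.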
TiledMmaQK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
// Mainloop setup QK
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_N,MMA_K,PIPE)
// Prepare: MMA PV
TiledMmaPV tiled_mma_pv;
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
// Mainloop setup PV
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_N,MMA_K,PIPE)
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
pipeline_q.consumer_wait(smem_pipe_read_q);
// mapping into QK accumulator
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
Tensor tPcP = thr_mma_qk.partition_C(cP);
int m_block = get<0>(blk_coord);
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
// Allocate PV acc
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulator, Fusion, decltype(params)> softmax{params};
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
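// The first KV tile is peeled: there is no running softmax state to rescale yet and the
// PV accumulator is initialized via gemm_zero_acc.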
if (true)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
Tensor acc_qk_fixed = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})));
Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
acc_qk_input(i) = static_cast<Element>(acc_qk(i));
}
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
//
// Advance the pipe
//
// Advance consumer pipeline
++smem_pipe_read;
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
if (warp_idx == 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
softmax.template step_interleave_begin<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tOrV); i++) {
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(acc_qk_element_mk); j++) {
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
}
warpgroup_arrive();
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(tOrV); j++) {
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
}
}
warpgroup_commit_batch();
// Wait for the pipeline MMAs to drain
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
if (warp_idx == 0) {
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
if (warp_idx == 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<2>(tOrV); i++) {
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(acc_qk_element_mk); j++) {
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
}
warpgroup_arrive();
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<1>(tOrV); j++) {
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
}
}
warpgroup_commit_batch();
// Wait for the pipeline MMAs to drain
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
if (warp_idx == 0) {
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_pv);
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
return make_tuple(acc_pv, lse);
}
};
} // namespace cutlass::fmha::collective
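// Schedule recap (informal): each mainloop iteration above issues the QK GEMM for the
// current KV tile, overlaps the online-softmax rescale of the running PV accumulator with
// the next TMA loads, and then streams the freshly exponentiated P fragments into the PV
// GEMM one K-block at a time via step_interleave_begin / step_interleave_step.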

View File

@ -0,0 +1,560 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
using cutlass::fmha::kernel::Tag;
using cutlass::fmha::kernel::find_option_t;
template<
class Element_,
class ElementAccumulatorQK_,
class ElementAccumulatorPV_,
class TileShape_, // SeqQ, SeqKV, Head
class LayoutQ_, class LayoutK_, class LayoutV_, // SeqX, Head, (Batches)
class Fusion,
class... Options
>
struct FmhaMainloopTmaWarpSpecialized {
using Element = Element_;
using ElementAccumulatorQK = ElementAccumulatorQK_;
using ElementAccumulatorPV = ElementAccumulatorPV_;
using TileShape = TileShape_;
using LayoutQ = LayoutQ_;
using LayoutK = LayoutK_;
using LayoutV = LayoutV_;
// Options
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
static constexpr bool kIsMainloopLocked = find_option_t<Tag::kIsMainloopLocked, false_type, Options...>::value;
static constexpr int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = find_option_t<Tag::kNumMmaWarpGroups, Int<2>, Options...>::value;
static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<5>, Options...>::value;
static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<NumMmaWarpGroups>, Options...>::value;
static const int kOuterLoads = 1;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<_1, _1, _1>;
static_assert(StagesQ::value >= NumMmaWarpGroups);
static_assert(Stages::value >= 2);
// 16B alignment lets us use TMA
static constexpr int Alignment = 16 / sizeof(Element);
using TileShapeQK = Shape<
decltype(tuple_element_t<0, TileShape>{} / Int<NumMmaWarpGroups>{}),
tuple_element_t<1, TileShape>,
tuple_element_t<2, TileShape>>;
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Element, LayoutQ, Alignment,
Element, LayoutK, Alignment,
ElementAccumulatorQK,
TileShapeQK, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since we do not load from smem at all
Element, LayoutK, Alignment,
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
ElementAccumulatorPV,
TileShapePV, ClusterShape, Stages,
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element);
using TileShapeOut = TileShapePV;
using TiledMmaOut = TiledMmaPV;
using ElementOut = ElementAccumulatorPV;
struct SharedStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
};
struct Arguments {
const Element* ptr_Q;
LayoutQ dQ;
const Element* ptr_K;
LayoutK dK;
const Element* ptr_V;
LayoutV dV;
};
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
float scale_softmax;
float scale_softmax_log2;
float rp_dropout;
};
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kQ,
MainloopPipelineQ,
Element,
SmemLayoutQ,
TMA_Q
>;
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kK,
MainloopPipeline,
Element,
SmemLayoutK,
TMA_K
>;
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
cutlass::fmha::collective::LoadKind::kV,
MainloopPipeline,
Element,
SmemLayoutV,
TMA_V
>;
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
return true
&& (get<4>(problem_size) <= get<2>(TileShape{}))
&& ((get<4>(problem_size) % Alignment) == 0)
&& ((get<2>(problem_size) % Alignment) == 0)
;
}
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
typename CollectiveMmaQK::Arguments {
args.ptr_Q, args.dQ,
args.ptr_K, args.dK,
}, /*workspace=*/ nullptr);
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
typename CollectiveMmaPV::Arguments {
args.ptr_K, args.dK, // never used, dummy
args.ptr_V, select<1,0,2>(args.dV),
}, /*workspace=*/ nullptr);
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
1.0f
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
template<bool kLoadQ, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_kv_maybe_q(
int block_rank_in_cluster,
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline, PipelineState& smem_pipe_write,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
int lane_predicate = cute::elect_one_sync();
uint16_t mcast_mask_b = 0;
if (lane_predicate == 1) {
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
}
}
}
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
[[maybe_unused]] int q_tile_count = NumMmaWarpGroups;
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
int k_tile_count = 2 * fusion_tile_count;
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
if constexpr (kLoadQ) {
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
if constexpr (kLoadQ) {
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
if constexpr (! kLoadQ) {
if (do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
}
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
if constexpr (kLoadQ) {
while (q_tile_count > 0) {
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}
}
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
}
}
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
CUTLASS_DEVICE void
load_maybe_q(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
SharedStorage& storage,
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
{
int lane_predicate = cute::elect_one_sync();
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
CUTLASS_PRAGMA_UNROLL
for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) {
int count = 1;
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count);
if (q_tile_count == 0 && do_barrier) {
load_warp_barrier.arrive();
load_warp_barrier.wait(/*phase=*/ 0);
do_barrier = false;
}
}
}
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
CUTLASS_DEVICE void
reduce(
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
SharedStorage& storage)
{ /* no-op */ }
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
CUTLASS_DEVICE auto
compute(
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
Params const& params, ProblemShape const& problem_size,
MainloopPipeline& pipeline, PipelineState& smem_pipe_read,
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q,
MainloopPipelineReducer&, PipelineStateReducer&,
SharedStorage& storage,
MathWgOrderBarrier& math_wg_order_barrier)
{
int thread_idx = int(threadIdx.x);
PipelineState smem_pipe_release = smem_pipe_read;
PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
TiledMmaQK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
// Mainloop setup QK
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_N,MMA_K,PIPE)
// Prepare: MMA PV
TiledMmaPV tiled_mma_pv;
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
// Mainloop setup PV
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_N,MMA_K,PIPE)
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
pipeline_q.consumer_wait(smem_pipe_read_q);
// mapping into QK accumulator
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
Tensor tPcP = thr_mma_qk.partition_C(cP);
int m_block = get<0>(wg_coord);
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
// Allocate PV acc
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulatorQK, Fusion, decltype(params)> softmax{params};
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
if (true)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
math_wg_order_barrier.wait();
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
math_wg_order_barrier.arrive();
++smem_pipe_read;
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
pipeline.consumer_wait(smem_pipe_read);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
// Advance consumer pipeline
++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
softmax.template step<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
pipeline.consumer_wait(smem_pipe_read, tok);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0)
{
--k_tile_count;
// Allocate QK acc
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
pipeline.consumer_wait(smem_pipe_read);
// MMA QK
warpgroup_fence_operand(acc_qk);
warpgroup_arrive();
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
warpgroup_commit_batch();
++smem_pipe_read;
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_qk);
warpgroup_fence_operand(acc_pv);
//if constexpr (kIsPersistent)
// if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
pipeline.consumer_wait(smem_pipe_read, tok);
// MMA PV
warpgroup_fence_operand(acc_pv);
warpgroup_fence_operand(acc_qk_fixed);
warpgroup_arrive();
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
warpgroup_commit_batch();
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
++smem_pipe_read;
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
}
if constexpr (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q);
// Wait for the pipeline MMAs to drain
warpgroup_wait<0>();
warpgroup_fence_operand(acc_pv);
if constexpr (kIsPersistent) pipeline.consumer_release(smem_pipe_release);
++smem_pipe_release;
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
return make_tuple(acc_pv, lse);
}
};
} // namespace cutlass::fmha::collective
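// Compared to FmhaMainloopTma above, this variant splits the Q extent across
// NumMmaWarpGroups math warpgroups (hence StagesQ >= NumMmaWarpGroups), serializes them
// with math_wg_order_barrier where kIsMainloopLocked is set, and supports persistent
// scheduling via the kIsPersistent option.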

View File

@ -0,0 +1,245 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/kernel_hardware_info.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
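// Issues a sequence of MMAs over the k-blocks of the given operand fragments,
// accumulating into tC. gemm_reset_zero_acc leaves the atom's accumulate flag
// as the caller set it for the first k-block and switches it to ScaleOut::One
// afterwards; gemm_zero_acc starts from ScaleOut::Zero so the accumulator is
// overwritten on the first MMA.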
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
constexpr int rA = decltype(rank(tA))::value;
constexpr int rB = decltype(rank(tB))::value;
constexpr int rC = decltype(rank(tC))::value;
if constexpr (rA == 2 && rB == 2 && rC == 1) {
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<1>(tA); k_block++) {
cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
} else {
static_assert(rA == 3 && rB == 3 && rC == 3);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
}
}
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
atom.accumulate_ = GMMA::ScaleOut::Zero;
gemm_reset_zero_acc(atom, tA, tB, tC);
}
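// Tree reduction over a register fragment: while the size is even, the two
// halves are combined pairwise; an odd-sized fragment falls back to a serial
// fold. fmha_max below is a max functor intended for use with it.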
template<typename T, typename Fn>
CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) {
if constexpr (decltype(size(t) % _2{} == _0{})::value) {
auto partial = make_tensor<typename T::value_type>(size(t) / _2{});
CUTE_UNROLL
for (int i = 0; i < size(partial); i++) {
partial(i) = fn(t(i), t(i + size(partial)));
}
return reduce(partial, fn);
} else {
auto result = t(_0{});
CUTE_UNROLL
for (int i = 1; i < size(t); i++) {
result = fn(result, t(i));
}
return result;
}
}
struct fmha_max {
CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); }
};
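// Splits the modes of layout `src` into two groups, depending on whether the
// corresponding entry in `ref` addresses coordinates below `thr` or at/above
// it; modes outside a group are replaced by size-1 placeholders and filtered
// out. The helpers below use this to separate the M/N (or M/K) modes of a
// thread-level accumulator or operand layout.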
template<typename Threshold, typename Source, typename Reference>
inline auto __device__ constexpr layout_separate(Threshold const& thr,
Source const& src, Reference const& ref) {
auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
if constexpr(decltype(r < thr)::value) {
return s;
} else {
return make_layout(_1{}, _0{});
}
}));
auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
if constexpr(decltype(r >= thr)::value) {
return s;
} else {
return make_layout(_1{}, _0{});
}
}));
return make_layout(lt, ge);
}
template<typename TiledMma, typename Acc>
inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) {
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{}));
auto V_M = get<0>(separated);
auto V_N = get<1>(separated);
return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc)));
}
template<typename TiledMma, typename Acc>
inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) {
return layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{}));
}
template<typename TiledMma, typename Acc>
inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& acc) {
return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout()));
}
template<typename TiledMma>
inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) {
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
make_layout(shape<0>(typename TiledMma::LayoutC_TV{})),
stride<0>(typename TiledMma::LayoutC_TV{}));
return get<1>(separated);
}
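// The overloads below rebuild an SS GMMA atom (or a TiledMMA built from one)
// as the equivalent RS atom via rs_op_selector, so that the A operand can be
// sourced from registers instead of shared memory.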
template<template<cute::GMMA::Major, cute::GMMA::Major, cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<tA, tB, sA, sB>> const& tiled_mma) {
using Atom = cute::MMA_Atom<Primitive<tA, tB, sA, sB>>;
using ElementA = typename Atom::ValTypeA;
using ElementB = typename Atom::ValTypeB;
using ElementC = typename Atom::ValTypeC;
using Shape_MNK = typename Atom::Shape_MNK;
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
return cute::MMA_Atom<RS>{};
}
template<template<cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<sA, sB>> const& tiled_mma) {
using Atom = cute::MMA_Atom<Primitive<sA, sB>>;
using ElementA = typename Atom::ValTypeA;
using ElementB = typename Atom::ValTypeB;
using ElementC = typename Atom::ValTypeC;
using Shape_MNK = typename Atom::Shape_MNK;
constexpr auto tA = cute::GMMA::Major::K;
constexpr auto tB = cute::GMMA::Major::K;
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
return cute::MMA_Atom<RS>{};
}
template<class Atom, class... Args>
CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA<Atom, Args...> const& tiled_mma) {
return cute::TiledMMA<decltype(convert_to_gmma_rs(Atom{})), Args...>{};
}
template<typename CLayout, typename AValueShape>
CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) {
return make_layout(
make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))),
make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c))));
}
template<class Layout, class Stages = _1>
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, make_layout(stages)));
}
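// Turns a QK accumulator fragment into an A-operand fragment for the PV GEMM
// (register-sourced GMMA). The values are first copied, with conversion, into
// a fragment laid out as the operand; for 8-bit element types the per-thread
// value ownership of accumulator and operand differ, so the values are then
// exchanged within each quad of threads using warp shuffles and byte permutes.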
template<class Element, class Accumulator, class OperandLayout_TV>
CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) {
Tensor operand = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv)));
Tensor operand_as_acc = make_tensor(operand.data(), acc.layout());
cute::copy(acc, operand_as_acc);
if constexpr (sizeof(Element) == 1) {
// 00 11 22 33 00 11 22 33 acc layout
// 00 00 11 11 22 22 33 33 operand layout
// BB AA AA BB AA BB BB AA conflict-free exchange pattern
// 16-bit exchange; so process two at a time potentially
int tid = threadIdx.x % 4;
auto values_u32 = recast<uint32_t>(operand);
CUTE_UNROLL
for (int n = 0; n < size<1>(values_u32); n++) {
CUTE_UNROLL
for (int k = 0; k < size<2>(values_u32); k++) {
CUTE_UNROLL
for (int ii = 0; ii < 8; ii += 4) {
uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k);
uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k);
// step A:
// t 1 v 0 -> t 0 v 1
// t 2 v 0 -> t 1 v 0
// t 0 v 1 -> t 2 v 0
// t 3 v 1 -> t 3 v 1
int v_to_send = tid == 1 || tid == 2 ? 0 : 1;
int v_to_recv = v_to_send;
int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF;
uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4);
// step B:
// t 0 v 0 -> t 0 v 0
// t 3 v 0 -> t 1 v 1
// t 1 v 1 -> t 2 v 1
// t 2 v 1 -> t 3 v 0
v_to_send = 1 - v_to_send;
v_to_recv = 1 - v_to_recv;
t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF;
uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4);
values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410);
values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632);
}
}
}
}
return operand;
}
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,156 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "../collective/fmha_common.hpp"
namespace cutlass::fmha::collective {
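// Forward-pass epilogue: writes the per-row log-sum-exp (LSE) vector straight
// from registers (only the threads owning column 0 of the accumulator store
// it, with a bounds check on the query index) and delegates the O tile store
// to an SM90 TMA warp-specialized collective epilogue.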
template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaFwdEpilogue {
static constexpr int Alignment = 16 / sizeof(Element);
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
cutlass::epilogue::TmaWarpSpecialized,
DefaultOperation
>::CollectiveOp;
struct Arguments {
Element* ptr_O;
cute::tuple<int, cute::_1, cute::tuple<int, int>> dO;
ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
};
struct Params {
ElementAccumulator* ptr_LSE;
cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
typename CollectiveEpilogueTMA::Params epilogue_TMA;
};
using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage;
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1,
make_shape(get<0>(problem_size), get<1>(problem_size)));
typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO};
return Params{
args.ptr_LSE, args.dLSE,
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace)
};
}
template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
CUTLASS_DEVICE void operator()(
TileShape const& tile_shape, BlkCoord const& blk_coord,
ResultTuple const& result, TiledMma const& tiled_mma,
ProblemShape const& problem_size, Params const& params,
LoadPipeline epi_load_pipeline,
TensorStorage& epi_tensor_storage)
{
using X = Underscore;
auto acc = get<0>(result);
auto lse = get<1>(result);
auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
int seqlen_q = get<2>(problem_size);
int num_batch = get<0>(problem_size);
int num_heads = get<1>(problem_size);
// Epilogue for lse
Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE),
make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)),
make_stride(_1{}, _0{}, get<1>(params.dLSE)));
Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{});
Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord));
Tensor tOgLSE = thr_mma.partition_C(gLSE);
Tensor cO = make_identity_tensor(take<0,2>(tile_shape));
Tensor tOcO = thr_mma.partition_C(cO);
if (get<1>(tOcO(_0{})) == 0) {
auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout()));
auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout()));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tOgLSE_mn); i++) {
if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) {
tOgLSE_mn(i, _0{}) = lse(i);
}
}
}
auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _,
make_shape(get<0>(problem_size), get<1>(problem_size)));
CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage);
using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_tma.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage
);
epilogue_tma.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,157 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "../collective/fmha_epilogue.hpp"
namespace cutlass::fmha::collective {
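// Backward-pass epilogue for dK and dV: the same SM90 TMA warp-specialized
// collective epilogue is instantiated twice, once per output tensor, each with
// its own shared-memory tensor storage slot; both store pipelines are drained
// via store_tail at the end.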
template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaBwdEpilogueKV {
static constexpr int Alignment = 16 / sizeof(Element);
struct Arguments {
Element* ptr_K;
cute::tuple<int, int, int, cute::_1> dK;
Element* ptr_V;
cute::tuple<int, int, int, _1> dV;
};
//using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90Compute<cutlass::first, Element, ElementAccumulator, RoundStyle>,
cutlass::epilogue::fusion::Sm90AccFetch
>;
using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
cutlass::epilogue::TmaWarpSpecialized,
DefaultOperation
>::CollectiveOp;
struct Params {
typename CollectiveEpilogueTMA::Params epilogue_K;
typename CollectiveEpilogueTMA::Params epilogue_V;
};
using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2];
using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;
template<class ProblemShape>
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
auto dK = make_stride(get<2>(args.dK), get<3>(args.dK),
make_stride(get<0>(args.dK), get<1>(args.dK)));
auto dV = make_stride(get<2>(args.dV), get<3>(args.dV),
make_stride(get<0>(args.dV), get<1>(args.dV)));
auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1,
make_shape(get<0>(problem_size), get<1>(problem_size)));
typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK};
typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV};
return Params{
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr),
CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr)
};
}
template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
CUTLASS_DEVICE void operator()(
TileShape const& tile_shape, BlkCoord const& blk_coord,
ResultTuple const& result, TiledMma const& tiled_mma,
ProblemShape const& problem_size, Params const& params,
LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage)
{
auto acc_k = get<0>(result);
auto acc_v = get<1>(result);
auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _,
make_shape(get<0>(problem_size), get<1>(problem_size)));
using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]};
CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]};
{
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_k.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage[0]
);
}
{
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
epilogue_v.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
epi_tensor_storage[1]
);
epilogue_k.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
epilogue_v.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next
);
}
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,283 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
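// Fusion hooks consumed by the mainloop: a fusion decides how many K tiles a
// CTA visits (get_trip_count, split into masked and unmasked portions) and may
// rewrite the QK accumulator just before softmax (before_softmax), e.g. to
// mask out positions. DefaultFusion visits every tile and applies no mask.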
struct DefaultFusion {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return ceil_div(get<3>(problem_size), get<1>(tile_shape));
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size);
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return 0;
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
return;
}
};
struct ResidualFusion : DefaultFusion {
using Base = DefaultFusion;
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return 1;
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
// This is useful if seqlen_k % kBlockN != 0, since it masks
// the remaining elements out from softmax.
// d % kHeadDim != 0 or seqlen_q % kBlockM != 0 do not suffer from similar
// issues as they are transparently taken care of by TMA and the
// epilogue, if it is instantiated with predication support.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<1>(pos) >= get<3>(problem_size)) {
acc_qk(i) = -INFINITY;
}
}
}
};
struct CausalFusion : DefaultFusion {
using Base = DefaultFusion;
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
// See note below on different ways to think about causal attention
// Again, we'd add the offset_q into the max_blocks_q calculation
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
return std::min(max_blocks_k, max_blocks_q);
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_masked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_unmasked_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
// There are two ways to do causal if N_Q != N_K
// (1) is to assume that the Q is at the beginning of the matrix
// - this is what we demonstrate here
// (2) is that it is at the end of the matrix
// - this is usually what we want for inference settings
// where we only compute the next row and use cache for the rest
// - if you'd like this, you only need to add an offset like so:
// get<0>(pos) + offset_q < get<1>(pos)
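// (an illustrative sketch of this variant follows after this struct)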
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<0>(pos) < get<1>(pos)) {
acc_qk(i) = -INFINITY;
}
}
}
};
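// Illustrative sketch only (not part of the original kernels): variant (2)
// from the note above, aligning the query block with the end of the key range
// as is typical for decode/inference. Only before_softmax and get_trip_count
// are adjusted here; the masked/unmasked split (get_masked_trip_count /
// get_unmasked_trip_count) is inherited unchanged and would need the same
// offset folded in for a production kernel.
struct CausalFromEndFusion : CausalFusion {
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    // Fold the query offset into the max_blocks_q calculation, as noted above.
    int offset_q = get<3>(problem_size) - get<2>(problem_size);
    int max_blocks_k = DefaultFusion::get_trip_count(blk_coord, tile_shape, problem_size);
    int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape));
    return std::min(max_blocks_k, max_blocks_q);
  }
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void before_softmax(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size
  ) {
    // Shift the diagonal so the last query row attends up to the last key column.
    int offset_q = get<3>(problem_size) - get<2>(problem_size);
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      if (get<0>(pos) + offset_q < get<1>(pos)) {
        acc_qk(i) = -INFINITY;
      }
    }
  }
};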
template<class Base>
struct FusionBwdAdapter {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size));
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
auto index_base = index_qk(_0{});
auto index_shape = shape(index_qk);
auto index_stride = transform_leaf(stride(index_qk), [](auto elem) {
if constexpr (is_scaled_basis<decltype(elem)>::value) {
if constexpr(decltype(elem.mode() == _0{})::value) {
return ScaledBasis<decltype(elem.value()), 1>(elem.value());
} else {
return ScaledBasis<decltype(elem.value()), 0>(elem.value());
}
} else {
return elem;
}
});
auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride));
Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size);
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
bool is_contributing(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return true;
}
};
template<>
struct FusionBwdAdapter<CausalFusion> {
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
int get_trip_count(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
return get<2>(problem_size) / get<0>(TileShape{});
}
template<class AccQK, class IndexQK, class ProblemSize>
CUTLASS_DEVICE
void before_softmax(
AccQK& acc_qk,
IndexQK const& index_qk,
ProblemSize const& problem_size
) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(acc_qk); i++) {
auto pos = index_qk(i);
if (get<1>(pos) < get<0>(pos)) {
acc_qk(i) = -INFINITY;
}
}
}
template<class BlkCoord, class TileShape, class ProblemSize>
CUTLASS_DEVICE
bool is_contributing(
BlkCoord const& blk_coord,
TileShape const& tile_shape,
ProblemSize const& problem_size
) {
int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape);
int min_k = get<1>(blk_coord) * get<1>(tile_shape);
return min_k <= max_q;
}
};
} // namespace cutlass::fmha::collective

View File

@ -0,0 +1,278 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief A universal device layer for CUTLASS 3.x-style kernels.
*/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template <class Kernel_>
class Universal {
public:
using Kernel = Kernel_;
static int const kThreadCount = Kernel::MaxThreadsPerBlock;
/// Argument structure: User API
using Arguments = typename Kernel::Arguments;
/// Argument structure: Kernel API
using Params = typename Kernel::Params;
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (Kernel::can_implement(args)) {
return Status::kSuccess;
}
else {
return Status::kInvalid;
}
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
return workspace_bytes;
}
/// Computes the grid shape
static dim3
get_grid_shape(Params const& params) {
return Kernel::get_grid_shape(params);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
// Initialize the Params structure
params_ = Kernel::to_underlying_arguments(args, workspace);
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
params_ = Kernel::to_underlying_arguments(args, workspace);
return Status::kSuccess;
}
/// Primary run() entry point API that is static, allowing users to create and manage their own params.
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = get_grid_shape(params);
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
cutlass::arch::synclog_setup();
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result && Status::kSuccess == launch_result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
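// Usage sketch (illustrative only; `MyFmhaKernel`, `args`, and `stream` are
// placeholders and error handling is elided):
//
//   using Op = cutlass::device::Universal<MyFmhaKernel>;
//   Op op;
//   if (Op::can_implement(args) != cutlass::Status::kSuccess) { /* reject */ }
//   void* workspace = nullptr;
//   cudaMalloc(&workspace, Op::get_workspace_size(args));
//   op.initialize(args, workspace, stream);
//   op.run(stream);   // or op(stream)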
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::device
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,299 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/*!
\file
\brief A device layer for CUTLASS 3.x-style FMHA backward kernels.
*/
// common
#include "cutlass/cutlass.h"
#include "../device/device_universal.hpp"
#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/fmha_epilogue_bwd.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<class Element, class ElementAccumulator, class TileShape, class Fusion, class... Options>
class FmhaBwd {
public:
/// Argument structure: User API
struct Arguments {
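// (batch, num_heads, seqlen_q, seqlen_k, head_dim); cf. the structured
// bindings `auto [B, H, Q, K, D]` used below.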
cute::tuple<int, int, int, int, int> problem_size;
const Element* ptr_Q;
cute::tuple<int, int, int, cute::_1> stride_Q;
const Element* ptr_K;
cute::tuple<int, int, int, cute::_1> stride_K;
const Element* ptr_V;
cute::tuple<int, int, int, cute::_1> stride_V;
const Element* ptr_O;
cute::tuple<int, int, int, cute::_1> stride_O;
const ElementAccumulator* ptr_LSE;
cute::tuple<int, int, _1> stride_LSE;
const Element* ptr_dO;
cute::tuple<int, int, int, cute::_1> stride_dO;
Element* ptr_dQ;
cute::tuple<int, int, int, cute::_1> stride_dQ;
Element* ptr_dK;
cute::tuple<int, int, int, cute::_1> stride_dK;
Element* ptr_dV;
cute::tuple<int, int, int, cute::_1> stride_dV;
cutlass::KernelHardwareInfo hw_info;
};
using OperationSumOdO = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>>;
using OperationConvert = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>>;
using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized<
Element, ElementAccumulator, TileShape,
cutlass::fmha::collective::FusionBwdAdapter<Fusion>, Options...>;
using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV<Element, ElementAccumulator, typename Mainloop::TileShapePV>;
using Operation = cutlass::device::Universal<
cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<
Mainloop,
Epilogue,
cutlass::fmha::kernel::TileSchedulerBwdAdapter<cutlass::fmha::kernel::IndividualTileScheduler>, Options...>>;
struct Params {
OperationSumOdO op_sum_OdO;
Operation op;
OperationConvert op_convert;
ElementAccumulator* dQ_acc;
size_t dQ_acc_size;
};
private:
Params params_;
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) {
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_sum_OdO = make_stride(H*Q, Q, _1{});
return typename OperationSumOdO::Arguments {
args.problem_size,
args.ptr_O, args.stride_O,
args.ptr_dO, args.stride_dO,
dest, stride_sum_OdO
};
}
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
auto stride_src_dQ = make_stride(B == 1 ? 0 : (H*Q*D), Q*D, D, _1{});
return typename OperationConvert::Arguments {
args.problem_size,
src, stride_src_dQ,
nullptr, stride_src_dQ,
nullptr, stride_src_dQ,
args.ptr_dQ, args.stride_dQ,
nullptr, args.stride_dK,
nullptr, args.stride_dV
};
}
static typename Operation::Arguments to_bwd_arguments(
Arguments const& args,
ElementAccumulator* sum_OdO = nullptr, cute::tuple<int, int, _1> const& stride_sum_OdO = {},
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, int, int, _1> const& stride_dQ = {}
) {
return typename Operation::Arguments{
args.problem_size,
{ args.ptr_Q, args.stride_Q,
args.ptr_K, args.stride_K,
args.ptr_V, args.stride_V,
args.ptr_dO, args.stride_dO,
args.ptr_LSE, args.stride_LSE,
sum_OdO, stride_sum_OdO,
dQ_acc, stride_dQ },
{ args.ptr_dK, args.stride_dK,
args.ptr_dV, args.stride_dV },
args.hw_info
};
}
public:
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
Status status = Status::kSuccess;
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = OperationConvert::can_implement(to_convert_arguments(args));
if (status != Status::kSuccess) {
return status;
}
status = Operation::can_implement(to_bwd_arguments(args));
if (status != Status::kSuccess) {
return status;
}
return status;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
// FP32 versions of outputs that are churned (start off with Q only)
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
return workspace_bytes;
}
/// Initializes state from arguments.
Status
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
params_.dQ_acc = dQ_acc;
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO);
auto args_convert = to_convert_arguments(args, dQ_acc);
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
params_.op_convert.initialize(args_convert, nullptr, stream);
auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ);
params_.op.initialize(args_bwd, nullptr, stream);
return Status::kSuccess;
}
/// Initializes state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
auto [B, H, Q, K, D] = args.problem_size;
D = cutlass::round_up(D, 8); // Alignment
Q = cutlass::round_up(Q, 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
return initialize_split(args, dQ_acc, sum_OdO, stream);
}
/// Primary run() entry point API that is static, allowing users to create and manage their own params.
/// Supplied params struct must be constructed by calling initialize() or initialize_split()
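/// The backward pass runs three dependent kernels on the given stream:
/// (1) op_sum_OdO computes the row-wise sum of O * dO, (2) op computes dK/dV
/// and accumulates dQ into an FP32 workspace that is zero-filled beforehand,
/// and (3) op_convert writes dQ back out in the output element type.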
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
Status result = Status::kSuccess;
result = params.op_sum_OdO.run(stream);
if (result != Status::kSuccess) {
return result;
}
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
if (cuda_result != cudaSuccess) {
return Status::kErrorInternal;
}
result = params.op.run(stream);
if (result != Status::kSuccess) {
return result;
}
result = params.op_convert.run(stream);
if (result != Status::kSuccess) {
return result;
}
return Status::kSuccess;
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,158 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "../collective/fmha_collective_tma.hpp"
#include "../collective/fmha_collective_tma_warpspecialized.hpp"
#include "../collective/fmha_epilogue.hpp"
#include "../kernel/fmha_kernel_tma.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::kernel {
template<
class Element_,
class ElementAccumulatorQK_,
class ElementAccumulatorPV_,
class TileShape_, // BlockQO, BlockKV, BlockHead
class LayoutQ_,
class LayoutK_,
class LayoutV_,
class Fusion,
class DispatchPolicy,
class... Options
>
struct FmhaBuilder;
template<
class Element,
class ElementAccumulator,
class TileShape, // BlockQO, BlockKV, BlockHead
class Fusion,
class... Options
>
struct FmhaBuilder<
Element,
ElementAccumulator,
ElementAccumulator,
TileShape,
cute::tuple<int, _1, cute::tuple<int, int>>,
cute::tuple<int, _1, cute::tuple<int, int>>,
cute::tuple<int, _1, cute::tuple<int, int>>,
Fusion,
cutlass::gemm::KernelTma,
Options...
> {
using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma<Element, ElementAccumulator, TileShape, Fusion, Options...>;
using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>;
using Kernel = cutlass::fmha::kernel::FmhaKernelTma<CollectiveMainloop, CollectiveEpilogue, Options...>;
};
template<
class Element,
class ElementAccumulatorQK,
class ElementAccumulatorPV,
class TileShape, // BlockQO, BlockKV, BlockHead
class LayoutQ,
class LayoutK,
class LayoutV,
class Fusion,
class... Options
>
struct FmhaBuilder<
Element,
ElementAccumulatorQK,
ElementAccumulatorPV,
TileShape,
LayoutQ,
LayoutK,
LayoutV,
Fusion,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Options...
> {
using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized<
Element, ElementAccumulatorQK, ElementAccumulatorPV,
TileShape, LayoutQ, LayoutK, LayoutV,
Fusion, Options...>;
using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>;
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<CollectiveMainloop, CollectiveEpilogue, TileScheduler, Options...>;
};
template<
class Element,
class ElementAccumulatorQK,
class ElementAccumulatorPV,
class TileShape, // BlockQO, BlockKV, BlockHead
class LayoutQ,
class LayoutK,
class LayoutV,
class Fusion,
class... Options
>
struct FmhaBuilder<
Element,
ElementAccumulatorQK,
ElementAccumulatorPV,
TileShape,
LayoutQ,
LayoutK,
LayoutV,
Fusion,
cutlass::gemm::KernelTmaWarpSpecializedPingpong,
Options...
> {
using Kernel = typename FmhaBuilder<
Element, ElementAccumulatorQK, ElementAccumulatorPV,
TileShape,
LayoutQ, LayoutK, LayoutV,
Fusion,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Options...,
Option<Tag::kIsPersistent, true_type>,
Option<Tag::kLoadsQSeparately, true_type>
>::Kernel;
};
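// Usage sketch (illustrative only; the tile shape and element types below are
// arbitrary examples, not defaults of this builder):
//
//   using Builder = cutlass::fmha::kernel::FmhaBuilder<
//       cutlass::half_t,                              // Element
//       float, float,                                 // ElementAccumulatorQK, ElementAccumulatorPV
//       Shape<_128, _128, _64>,                       // TileShape: BlockQO, BlockKV, BlockHead
//       cute::tuple<int, _1, cute::tuple<int, int>>,  // LayoutQ
//       cute::tuple<int, _1, cute::tuple<int, int>>,  // LayoutK
//       cute::tuple<int, _1, cute::tuple<int, int>>,  // LayoutV
//       cutlass::fmha::collective::CausalFusion,
//       cutlass::gemm::KernelTmaWarpSpecializedCooperative>;
//   using Operation = cutlass::device::Universal<typename Builder::Kernel>;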
} // namespace cutlass::fmha::kernel

View File

@ -0,0 +1,143 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
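// Conversion kernel run after the main backward kernel: each CTA handles one
// (batch, head) slice and a kBlockSeq-row chunk of the sequence, reading the
// FP32 workspace tensors (any of dQ/dK/dV whose source pointer is non-null)
// and writing them out in the target element type with vectorized loads and
// stores of kElementsPerLoad values.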
template<class Element, class ElementAccumulator>
struct FmhaKernelBwdConvert {
struct Arguments {
tuple<int, int, int, int, int> problem_size;
const ElementAccumulator* ptr_src_dQ;
tuple<int, int, int, _1> stride_src_dQ;
const ElementAccumulator* ptr_src_dK;
tuple<int, int, int, _1> stride_src_dK;
const ElementAccumulator* ptr_src_dV;
tuple<int, int, int, _1> stride_src_dV;
Element* ptr_dest_dQ;
tuple<int, int, int, _1> stride_dest_dQ;
Element* ptr_dest_dK;
tuple<int, int, int, _1> stride_dest_dK;
Element* ptr_dest_dV;
tuple<int, int, int, _1> stride_dest_dV;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm90;
static const int kBlockSeq = 8;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kNumThreadsD = 16;
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 4;
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
static bool can_implement(Arguments const& args) {
return get<4>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
template<class StrideSrc, class StrideDest>
CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y;
auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y;
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
if (idx_s >= count) continue;
auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src);
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
ElementAccumulator value_src[kElementsPerLoad];
Element value_dest[kElementsPerLoad];
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAccumulator> * kElementsPerLoad>;
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
value_dest[v] = value_src[v];
}
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
}
}
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size));
}
if (params.ptr_src_dK != nullptr) {
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size));
}
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size));
}
}
};
} // namespace cutlass::fmha::kernel
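
A minimal host-side sketch (not part of the diff; the problem size below is hypothetical) of how FmhaKernelBwdConvert carves up the work: one CTA per (batch, head, kBlockSeq-row slice of the longer sequence), with 16 threads striding the head dimension in 4-element vectors and 8 threads striding rows.

#include <algorithm>
#include <cstdio>

int main() {
  int B = 4, H = 16, Q = 1024, K = 1024, D = 128;            // hypothetical problem size
  const int kBlockSeq = 8, kNumThreadsD = 16, kNumThreadsSeq = 128 / kNumThreadsD;
  const int kElementsPerLoad = 4;                            // can_implement(): D % 4 == 0
  int grid_x = B, grid_y = H;
  int grid_z = (std::max(Q, K) + kBlockSeq - 1) / kBlockSeq; // ceil_div over the longer sequence
  std::printf("grid=(%d,%d,%d) block=(%d,%d,1), can_implement: %s\n",
              grid_x, grid_y, grid_z, kNumThreadsD, kNumThreadsSeq,
              (D % kElementsPerLoad == 0) ? "yes" : "no");
  return 0;
}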

View File

@ -0,0 +1,134 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<class Element, class ElementAccumulator>
struct FmhaKernelBwdSumOdO {
struct Arguments {
cute::tuple<int, int, int, int, int> problem_size;
const Element* ptr_O;
cute::tuple<int, int, int, cute::_1> stride_O;
const Element* ptr_dO;
cute::tuple<int, int, int, cute::_1> stride_dO;
ElementAccumulator* ptr_sum_OdO;
cute::tuple<int, int, _1> stride_sum_OdO;
};
using Params = Arguments;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int SharedStorageSize = 0;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = 128;
using ArchTag = cutlass::arch::Sm90;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static const int kBlockQ = 16;
static const int kNumThreadsD = 8;
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
static const int kElementsPerLoad = 2;
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
static bool can_implement(Arguments const& args) {
return get<4>(args.problem_size) % kElementsPerLoad == 0;
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size));
return grid;
}
static dim3 get_block_shape() {
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return args;
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO);
CUTLASS_PRAGMA_UNROLL
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
if (idx_q >= get<2>(params.problem_size)) continue;
ElementAccumulator acc = 0;
auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O);
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO);
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO);
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
Element value_O[kElementsPerLoad];
Element value_dO[kElementsPerLoad];
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
for (int v = 0; v < kElementsPerLoad; v++) {
acc += value_O[v] * value_dO[v];
}
}
for (int i = 1; i < kNumThreadsD; i *= 2) {
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
}
if (threadIdx.x == 0) {
*ptr_sum_OdO_bhq = acc;
}
}
}
};
} // namespace cutlass::fmha::kernel
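
The sum-of-O·dO kernel above reduces each row's partial dot products across its 8 head-dimension lanes with an XOR butterfly shuffle. A standalone CUDA sketch of just that reduction pattern (hypothetical values; any GPU with warp shuffles will do):

#include <cstdio>

__global__ void butterfly_reduce_demo() {
  const int kNumThreadsD = 8;
  float acc = threadIdx.x;                    // stand-in for a per-lane partial sum
  for (int i = 1; i < kNumThreadsD; i *= 2) { // XOR butterfly within each 8-lane segment
    acc += __shfl_xor_sync(0xffffffffu, acc, i, kNumThreadsD);
  }
  if (threadIdx.x % kNumThreadsD == 0) {      // segment leader holds the full sum
    printf("group %d sum = %f\n", threadIdx.x / kNumThreadsD, acc);
  }
}

int main() {
  butterfly_reduce_demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}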

View File

@ -0,0 +1,222 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/arch.h"
#include "../kernel/fmha_tile_scheduler.hpp"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::kernel {
template<
class CollectiveMainloop,
class CollectiveEpilogue,
class... Options
>
struct FmhaKernelTma {
// Options
static constexpr int kBlocksPerSM = find_option_t<Tag::kBlocksPerSM, Int<2>, Options...>::value;
using Element = typename CollectiveMainloop::Element;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using TileScheduler = IndividualTileScheduler;
using StagesQ = typename CollectiveMainloop::StagesQ;
using Stages = typename CollectiveMainloop::Stages;
using TileShape = typename CollectiveMainloop::TileShape;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ;
using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK;
struct SharedStorage {
union {
typename CollectiveMainloop::SharedStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue;
};
using PipelineStorage = typename MainloopPipeline::SharedStorage;
using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage;
alignas(16) PipelineStorage pipeline_storage;
alignas(16) PipelineStorageQ pipeline_storage_q;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) EpiLoadPipelineStorage epi_load;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
using ProblemShape = cute::tuple<int, int, int, int, int>;
struct Arguments {
ProblemShape problem_size;
typename CollectiveMainloop::Arguments mainloop;
typename CollectiveEpilogue::Arguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_size;
typename CollectiveMainloop::Params mainloop;
typename CollectiveEpilogue::Params epilogue;
typename TileScheduler::Params tile_scheduler;
};
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
static const int MinBlocksPerMultiprocessor = kBlocksPerSM;
static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
using ArchTag = cutlass::arch::Sm90;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static bool can_implement(Arguments const& args) {
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
}
static dim3 get_grid_shape(Params const& params) {
return TileScheduler::get_grid_shape(params.tile_scheduler);
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return Params{
args.problem_size,
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
};
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
TileScheduler tile_scheduler{params.tile_scheduler};
// Shared memory.
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
int thread_idx = int(threadIdx.x);
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
int warp_idx = cutlass::canonical_warp_idx_sync();
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
int lane_predicate = cute::elect_one_sync();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
}
PipelineParamsQ pipeline_params_q;
pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); // Q
pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer;
pipeline_params_q.is_leader = warp_group_thread_idx == 0;
pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); // KV
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup;
MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{});
MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{});
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer;
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
PipelineState smem_pipe_read;
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineStateQ smem_pipe_read_q;
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer blocks in the Cluster
// and to finish smem init
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
}
else {
__syncthreads();
}
auto blk_coord = tile_scheduler.get_block_coord();
CollectiveMainloop collective_mainloop;
auto result = collective_mainloop.compute(
block_rank_in_cluster,
blk_coord, params.mainloop, params.problem_size,
pipeline, smem_pipe_read, smem_pipe_write,
pipeline_q, smem_pipe_read_q, smem_pipe_write_q,
storage.mainloop
);
CollectiveEpilogue epilogue;
epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord,
result, typename CollectiveMainloop::TiledMmaPV{},
params.problem_size, params.epilogue,
epi_load_pipeline, storage.epilogue);
}
};
} // namespace cutlass::fmha::kernel
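
One detail worth calling out in FmhaKernelTma's SharedStorage: the mainloop and epilogue tensors sit in a union because they are not live at the same time, so the kernel reserves max() of the two footprints rather than their sum. A small sketch with hypothetical sizes:

#include <cstdio>

struct MainloopSmem { char q_kv_tiles[48 * 1024]; };  // hypothetical tile storage
struct EpilogueSmem { char o_tile[16 * 1024]; };      // hypothetical output staging

union SharedTensors { MainloopSmem mainloop; EpilogueSmem epilogue; };

int main() {
  std::printf("union: %zu bytes vs struct: %zu bytes\n",
              sizeof(SharedTensors), sizeof(MainloopSmem) + sizeof(EpilogueSmem));
  return 0;
}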

View File

@ -0,0 +1,418 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/arch.h"
#include "../kernel/fmha_options.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class CollectiveMainloop,
class CollectiveEpilogue,
class TileScheduler,
class... Options
>
struct FmhaKernelTmaWarpSpecialized {
// Options
static constexpr bool kIsEpilogueLocked = find_option_t<Tag::kIsEpilogueLocked, false_type, Options...>::value;
static constexpr bool kLoadsQSeparately = find_option_t<Tag::kLoadsQSeparately, false_type, Options...>::value;
static const int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups;
using TileShape = typename CollectiveMainloop::TileShape;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ;
using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline;
using MainloopPipelineReducer = cutlass::PipelineAsync<2>;
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<
StagesPerMathWarpGroup, NumMmaWarpGroups>;
struct TensorStorageStruct {
typename CollectiveMainloop::SharedStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
};
union TensorStorageUnion {
typename CollectiveMainloop::SharedStorage mainloop;
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
};
using TensorStorage = std::conditional_t<CollectiveMainloop::kIsPersistent, TensorStorageStruct, TensorStorageUnion>;
struct SharedStorage {
TensorStorage tensors;
using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage;
using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage;
using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage;
alignas(16) PipelineStorageInner pipeline_storage_inner;
alignas(16) PipelineStorageOuter pipeline_storage_outer;
alignas(16) PipelineStorageReducer pipeline_storage_reducer;
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) EpiLoadPipelineStorage epi_load;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
using ProblemShape = cute::tuple<int, int, int, int, int>;
struct Arguments {
ProblemShape problem_size;
typename CollectiveMainloop::Arguments mainloop;
typename CollectiveEpilogue::Arguments epilogue;
KernelHardwareInfo hw_info;
};
struct Params {
ProblemShape problem_size;
typename CollectiveMainloop::Params mainloop;
typename CollectiveEpilogue::Params epilogue;
typename TileScheduler::Params tile_scheduler;
};
using PipelineParamsInner = typename MainloopPipelineInner::Params;
using PipelineStateInner = typename cutlass::PipelineState<MainloopPipelineInner::Stages>;
using PipelineParamsOuter = typename MainloopPipelineOuter::Params;
using PipelineStateOuter = typename cutlass::PipelineState<MainloopPipelineOuter::Stages>;
using PipelineParamsReducer = typename MainloopPipelineReducer::Params;
using PipelineStateReducer = typename cutlass::PipelineState<MainloopPipelineReducer::Stages>;
static const int MinBlocksPerMultiprocessor = 1;
static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup;
using ArchTag = cutlass::arch::Sm90;
static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8;
static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup;
static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8;
static size_t get_workspace_size(Arguments const& args) { return 0; }
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
return cutlass::Status::kSuccess;
}
static bool can_implement(Arguments const& args) {
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
}
static dim3 get_grid_shape(Params const& params) {
return TileScheduler::get_grid_shape(params.tile_scheduler);
}
static dim3 get_block_shape() {
dim3 block(MaxThreadsPerBlock, 1, 1);
return block;
}
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return Params{
args.problem_size,
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
};
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
enum class WarpGroupRole {
Producer = 0,
Consumer0 = 1,
Consumer1 = 2,
Consumer2 = 3,
Consumer3 = 4,
};
enum class ProducerWarpRole {
LoadKV = 1,
Reducer = 0,
MaybeLoadQ = 2, // if kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it)
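// (warps 0-3 of the producer warp group are assigned these roles by warp index)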
MainloopEpilogue = 3,
};
static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV;
TileScheduler tile_scheduler{params.tile_scheduler};
// Shared memory.
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
int lane_idx = cutlass::canonical_lane_idx();
int warp_idx = cutlass::canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup;
int warp_group_idx = cutlass::canonical_warp_group_idx();
auto warp_group_role = WarpGroupRole(warp_group_idx);
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0;
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
}
PipelineParamsOuter pipeline_params_outer;
pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes;
pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ);
pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup;
PipelineParamsInner pipeline_params_inner;
pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes;
pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV);
pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
PipelineParamsReducer pipeline_params_reducer;
pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp;
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) {
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) {
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) {
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1 ||
warp_group_role == WarpGroupRole::Consumer2 ||
warp_group_role == WarpGroupRole::Consumer3
) {
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer;
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer;
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer;
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{});
MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{});
MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer);
// State variables used for iterating the circular buffer
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
PipelineStateInner smem_pipe_read_inner;
PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state<MainloopPipelineInner>();
PipelineStateOuter smem_pipe_read_outer;
PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state<MainloopPipelineOuter>();
PipelineStateReducer smem_pipe_read_reducer;
PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state<MainloopPipelineReducer>();
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_math_wg_order_barrier.group_id = consumer_warp_group_idx;
params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group
MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier);
// Epilogue Load pipeline
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
if constexpr (kLoadsQSeparately) {
if ((warp_idx == 0) && lane_predicate) {
storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp);
}
cutlass::arch::fence_barrier_init();
}
// We need this to guarantee that the Pipeline init is visible
// To all producers and consumer blocks in the Cluster
// and to finish smem init
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
}
else {
__syncthreads();
}
CollectiveMainloop collective_mainloop;
if (warp_group_role == WarpGroupRole::Producer) {
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
if (producer_warp_role == ProducerWarpRole::LoadKV) {
bool do_barrier = kLoadsQSeparately;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
collective_mainloop.template load_kv_maybe_q<!kLoadsQSeparately>(
block_rank_in_cluster,
blk_coord, params.mainloop, params.problem_size,
pipeline_inner, smem_pipe_write_inner,
pipeline_outer, smem_pipe_write_outer,
storage.tensors.mainloop,
storage.load_warp_barrier, do_barrier
);
do_barrier = false;
}
}
else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) {
bool do_barrier = true;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
collective_mainloop.load_maybe_q(
blk_coord, params.mainloop, params.problem_size,
pipeline_outer, smem_pipe_write_outer,
storage.tensors.mainloop,
storage.load_warp_barrier, do_barrier
);
do_barrier = false;
}
} else if (producer_warp_role == ProducerWarpRole::Reducer) {
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
collective_mainloop.reduce(
blk_coord, params.mainloop, params.problem_size,
pipeline_reducer, smem_pipe_read_reducer,
storage.tensors.mainloop
);
}
}
}
else if (
warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1 ||
warp_group_role == WarpGroupRole::Consumer2 ||
warp_group_role == WarpGroupRole::Consumer3
) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto wg_coord = blk_coord;
constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads;
if (warp_group_role == WarpGroupRole::Consumer0) {
smem_pipe_read_outer.advance(0 * kOuterLoads);
}
else if (warp_group_role == WarpGroupRole::Consumer1) {
smem_pipe_read_outer.advance(1 * kOuterLoads);
}
else if (warp_group_role == WarpGroupRole::Consumer2) {
smem_pipe_read_outer.advance(2 * kOuterLoads);
}
else if (warp_group_role == WarpGroupRole::Consumer3) {
smem_pipe_read_outer.advance(3 * kOuterLoads);
}
constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 0 : 1;
auto& wg_block = get<wg_dim>(wg_coord);
if (warp_group_role == WarpGroupRole::Consumer0) {
wg_block = NumMmaWarpGroups * wg_block + 0;
}
else if (warp_group_role == WarpGroupRole::Consumer1) {
wg_block = NumMmaWarpGroups * wg_block + 1;
}
else if (warp_group_role == WarpGroupRole::Consumer2) {
wg_block = NumMmaWarpGroups * wg_block + 2;
}
else if (warp_group_role == WarpGroupRole::Consumer3) {
wg_block = NumMmaWarpGroups * wg_block + 3;
}
auto result = collective_mainloop.compute(
blk_coord, wg_coord,
params.mainloop, params.problem_size,
pipeline_inner, smem_pipe_read_inner,
pipeline_outer, smem_pipe_read_outer,
pipeline_reducer, smem_pipe_write_reducer,
storage.tensors.mainloop,
math_wg_order_barrier
);
if (warp_group_role == WarpGroupRole::Consumer0) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0));
}
if constexpr (NumMmaWarpGroups >= 2) {
if (warp_group_role == WarpGroupRole::Consumer1) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1));
}
}
if constexpr (NumMmaWarpGroups >= 3) {
if (warp_group_role == WarpGroupRole::Consumer2) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2));
}
}
if constexpr (NumMmaWarpGroups >= 4) {
if (warp_group_role == WarpGroupRole::Consumer3) {
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3));
}
}
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.wait(); }
CollectiveEpilogue epilogue;
epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord,
result, typename CollectiveMainloop::TiledMmaPV{},
params.problem_size, params.epilogue,
epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]);
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.arrive(); }
}
}
}
};
} // namespace cutlass::fmha::kernel
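
The Load/Mma register split above is plain integer arithmetic over the 64K-register SM budget. A host-side sketch of the same computation, assuming NumMmaWarpGroups = 2 (a hypothetical configuration):

#include <cstdio>

int main() {
  const int NumMmaWarpGroups = 2, NumLoadWarpGroups = 1, ThreadsPerWG = 128;
  const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * ThreadsPerWG;
  const unsigned LoadRegs = 40 - 2 * 8;  // 24 registers/thread for the load warp group
  const unsigned TotalSupply =
      (64 * 1024 / MaxThreadsPerBlock / 1 / 8) * 8 * MaxThreadsPerBlock / ThreadsPerWG; // 504
  const unsigned MmaRegs = ((TotalSupply - LoadRegs) / NumMmaWarpGroups / 8) * 8;       // 240
  std::printf("load WG: %u regs/thread, each MMA WG: %u regs/thread, total: %u <= 65536\n",
              LoadRegs, MmaRegs,
              ThreadsPerWG * LoadRegs + NumMmaWarpGroups * ThreadsPerWG * MmaRegs);
  return 0;
}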

View File

@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause
  *
  * Redistribution and use in source and binary forms, with or without
@ -28,51 +28,56 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/
 #pragma once
-#include <cute/config.hpp>                     // CUTE_HOST_DEVICE
-#include <cute/numeric/integral_constant.hpp>  // cute::true_type
-namespace cute
-{
-template <class T>
-struct ConstantTensor
-{
-  template <class... Coords>
-  CUTE_HOST_DEVICE constexpr
-  T const&
-  operator()(Coords const&...) const {
-    return val_;
-  }
-  T val_;
-};
-struct TrivialPredTensor
-{
-  template <class... Coords>
-  CUTE_HOST_DEVICE constexpr
-  true_type
-  operator()(Coords const&...) const {
-    return {};
-  }
-};
-template <class Fn>
-struct FunctionPredTensor
-{
-  CUTE_HOST_DEVICE constexpr
-  FunctionPredTensor(Fn const& fn) : fn_(fn) {}
-  template <class... Coords>
-  CUTE_HOST_DEVICE constexpr
-  auto
-  operator()(Coords const&... coords) const {
-    return fn_(coords...);
-  }
-  Fn const& fn_;
-};
-} // end namespace cute
+#include "cutlass/cutlass.h"
+namespace cutlass::fmha::kernel {
+template<auto kTag, typename Default, typename... Options>
+struct find_option;
+template<auto kTag, typename Default>
+struct find_option<kTag, Default> {
+  using option_value = Default;
+};
+template<auto kTag, typename Default, typename Option, typename... Options>
+struct find_option<kTag, Default, Option, Options...> :
+  std::conditional_t<
+    Option::tag == kTag,
+    Option,
+    find_option<kTag, Default, Options...>
+  >
+{};
+template<auto kTag, typename Default, typename... Options>
+using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
+enum class Tag {
+  kIsPersistent,
+  kNumMmaWarpGroups,
+  kLoadsQSeparately,
+  kIsMainloopLocked,
+  kIsEpilogueLocked,
+  kStagesQ,
+  kStagesKV,
+  kEpilogueKind,
+  kBlocksPerSM,
+  kClusterM,
+  kAccQK
+};
+template<auto kTag, class Value>
+struct Option {
+  static constexpr auto tag = kTag;
+  using option_value = Value;
+};
+} // namespace cutlass::fmha::kernel
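
A usage sketch for the option machinery above (the include path is an assumption; adjust to the example layout): find_option_t scans the Option pack for a matching Tag and falls back to the default when none is supplied.

#include "fmha_options.hpp"                     // path is an assumption
#include <cute/numeric/integral_constant.hpp>

using namespace cutlass::fmha::kernel;
using cute::Int;

// No kBlocksPerSM option supplied: the Int<2> default is selected.
static_assert(find_option_t<Tag::kBlocksPerSM, Int<2>>::value == 2, "default picked");
// kBlocksPerSM overridden by an Option later in the pack.
static_assert(find_option_t<Tag::kBlocksPerSM, Int<2>,
                            Option<Tag::kIsPersistent, cute::true_type>,
                            Option<Tag::kBlocksPerSM, Int<1>>>::value == 1, "override picked");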

View File

@ -0,0 +1,204 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct IndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
IndividualTileScheduler(Params const&) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size));
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
}
CUTLASS_DEVICE
IndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct PersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_h;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size);
return Params {
num_blocks,
{ num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) },
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, bidh;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_h(block_decode, bidh, block_decode);
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
}
CUTLASS_DEVICE
PersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
template<typename Base>
struct TileSchedulerBwdAdapter {
using Params = typename Base::Params;
Base base_;
CUTLASS_DEVICE
TileSchedulerBwdAdapter(Params const& params) : base_(params) {}
template<class ProblemSize, class ClusterShape, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, TileShape const& tile_shape)
{
using namespace cute;
return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape));
}
static dim3 get_grid_shape(Params const& params) {
return Base::get_grid_shape(params);
}
CUTLASS_DEVICE
bool is_valid() {
return base_.is_valid();
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return select<1,0,2>(base_.get_block_coord());
}
CUTLASS_DEVICE
TileSchedulerBwdAdapter& operator++() {
++base_;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
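
A plain-integer sketch of PersistentTileScheduler: a flat persistent block index is decoded into (m_block, batch, head) by successive div/mod, mirroring the FastDivmod calls in get_block_coord(), and operator++ strides by the launched grid size. The values below are a hypothetical launch.

#include <cstdio>

int main() {
  const int num_m_blocks = 12, B = 4, H = 16;   // hypothetical problem partitioning
  const int num_blocks = num_m_blocks * B * H;
  const int sm_count = 132;                     // grid.x = min(num_blocks, sm_count)
  for (int block_idx = 7; block_idx < num_blocks; block_idx += sm_count) {
    int d = block_idx;
    int m_block = d % num_m_blocks; d /= num_m_blocks;
    int bidb    = d % B;            d /= B;
    int bidh    = d % H;
    std::printf("block %d -> (m=%d, b=%d, h=%d)\n", block_idx, m_block, bidb, bidh);
  }
  return 0;
}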

View File

@ -0,0 +1,357 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
class TensorDQ, /* class TensorDK, class TensorDV, */
class Fusion
>
void __global__ fmha_bwd_reference_dQ_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
Fusion fusion
) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_K] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
}
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
}
mDQ(idx_Q, idx_D, idx_L) = acc;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
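// The three backward reference kernels in this file share the same per-element math
// (restated from the code, with s = 1/sqrt(D), LSE taken from the forward pass, and the
// fusion's before_softmax hook applied to the raw q.k dot product):
//   P(i,j)  = exp(s * <q_i, k_j> - LSE(i))
//   dS(i,j) = P(i,j) * s * (<dO_i, v_j> - <dO_i, o_i>)
//   dQ(i,:) = sum_j dS(i,j) * k_j     (fmha_bwd_reference_dQ_kernel)
//   dK(j,:) = sum_i dS(i,j) * q_i     (fmha_bwd_reference_dK_kernel)
//   dV(j,:) = sum_i P(i,j)  * dO_i    (fmha_bwd_reference_dV_kernel)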
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/* class TensorDQ, */ class TensorDK, /* class TensorDV, */
class Fusion
>
void __global__ fmha_bwd_reference_dK_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
Fusion fusion
) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
ElementAccumulator acc_dov = 0;
ElementAccumulator acc_doo = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
}
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
}
mDK(idx_K, idx_D, idx_L) = acc;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/* class TensorDQ, class TensorDK, */ class TensorDV,
class Fusion
>
void __global__ fmha_bwd_reference_dV_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
Fusion fusion
) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
ElementAccumulator acc_qk = 0;
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
}
auto id = make_identity_tensor(make_shape(1, 1));
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc_qk;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
acc_qk = frag(0);
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)));
}
__syncthreads();
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L);
}
mDV(idx_K, idx_D, idx_L) = acc;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/
class Fusion
>
void fmha_bwd_reference_dQ(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_bwd_reference_dQ_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDQ, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/
class Fusion
>
void fmha_bwd_reference_dK(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_bwd_reference_dK_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDK, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
/** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/
class Fusion
>
void fmha_bwd_reference_dV(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
/** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
dim3 block(256);
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_bwd_reference_dV_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDV, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ, class TensorK, class TensorV,
class TensorO, class TensorLSE, class TensorDO,
class TensorDQ, class TensorDK, class TensorDV,
class Fusion
>
void fmha_bwd_reference(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE, TensorDO mDO,
TensorDQ mDQ, TensorDK mDK, TensorDV mDV,
Fusion fusion
) {
fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,156 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ,
class TensorK,
class TensorV,
class TensorO,
class TensorLSE,
class Fusion
>
void __global__ fmha_reference_kernel(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE,
Fusion fusion
) {
using namespace cute;
using Element = typename TensorO::value_type;
using ElementAccumulator = typename TensorLSE::value_type;
extern __shared__ char mS_mem[];
Element* mS = reinterpret_cast<Element*>(mS_mem);
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));
auto id = make_identity_tensor(make_shape(1, 1));
for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) {
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) {
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) {
acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L);
}
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
frag(0) = acc;
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
mS[idx_K] = static_cast<Element>(frag(0) * softmax_scale);
}
__syncthreads();
ElementAccumulator maxS = -std::numeric_limits<ElementAccumulator>::infinity();
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
maxS = std::max<ElementAccumulator>(maxS, mS[idx_K]);
}
if (maxS == -std::numeric_limits<ElementAccumulator>::infinity()) maxS = 0;
__syncthreads();
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
mS[idx_K] = static_cast<Element>(exp(mS[idx_K] - maxS));
}
__syncthreads();
ElementAccumulator sum = 0;
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
sum += mS[idx_K];
}
Element scale = static_cast<Element>(1.0 / sum);
for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) {
ElementAccumulator acc = 0;
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale;
}
mO(idx_Q, idx_D, idx_L) = static_cast<Element>(acc);
}
if (threadIdx.x == 0) {
mLSE(idx_Q, idx_L) = log(sum) + maxS;
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ProblemShape,
class TensorQ,
class TensorK,
class TensorV,
class TensorO,
class TensorLSE,
class Fusion
>
void fmha_reference(
ProblemShape problem_shape,
TensorQ mQ, TensorK mK, TensorV mV,
TensorO mO, TensorLSE mLSE,
Fusion fusion
) {
using namespace cute;
dim3 grid(size<0>(mO), size<2>(mO), 1);
dim3 block(256);
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
if (shared_mem >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
auto result = cudaFuncSetAttribute(
fmha_reference_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, Fusion>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_mem);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return;
}
}
fmha_reference_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
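Editor's sketch (not part of the commit): a minimal host-side driver for fmha_reference with float tensors, made-up extents, and a no-op fusion functor. The real example wires in its own fusion types (e.g. for masking) and element types.

#include <thrust/device_vector.h>
#include <cute/tensor.hpp>

// Hypothetical no-op fusion matching the before_softmax call in the kernel above.
struct FwdNoFusionSketch {
  template <class Frag, class Coord, class ProblemShape>
  CUTE_HOST_DEVICE void before_softmax(Frag&, Coord const&, ProblemShape const&) const {}
};

void run_fmha_reference_sketch() {
  using namespace cute;
  int Q = 128, K = 128, D = 64, L = 2;  // made-up seqlen_Q, seqlen_K, head dim, batch*heads
  thrust::device_vector<float> q(Q*D*L), k(K*D*L), v(K*D*L), o(Q*D*L), lse(Q*L);
  auto mQ   = make_tensor(make_gmem_ptr(q.data().get()),   make_layout(make_shape(Q, D, L)));
  auto mK   = make_tensor(make_gmem_ptr(k.data().get()),   make_layout(make_shape(K, D, L)));
  auto mV   = make_tensor(make_gmem_ptr(v.data().get()),   make_layout(make_shape(K, D, L)));
  auto mO   = make_tensor(make_gmem_ptr(o.data().get()),   make_layout(make_shape(Q, D, L)));
  auto mLSE = make_tensor(make_gmem_ptr(lse.data().get()), make_layout(make_shape(Q, L)));
  // Launches fmha_reference_kernel over (seqlen_Q, L) blocks; O and LSE are written in place.
  fmha_reference(make_shape(Q, K, D, L), mQ, mK, mV, mO, mLSE, FwdNoFusionSketch{});
  cudaDeviceSynchronize();
}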

View File

@ -0,0 +1,129 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cmath>
#include "cutlass/util/device_memory.h"
template<typename Element>
__global__ void reference_abs_diff_kernel(
Element* data, Element* data_ref, size_t count,
double* max_diff, double* sum_diff,
bool print_diff
) {
double thread_max_diff = 0;
double thread_sum_diff = 0;
__shared__ double block_max_diff;
__shared__ double block_sum_diff;
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
double diff = fabs(data[i] - data_ref[i]);
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
thread_max_diff = fmax(diff, thread_max_diff);
thread_sum_diff += diff;
}
for (int i = 0; i < blockDim.x; i++) {
if (i == threadIdx.x) {
if (i == 0) {
block_max_diff = thread_max_diff;
block_sum_diff = thread_sum_diff;
} else {
block_max_diff = fmax(block_max_diff, thread_max_diff);
block_sum_diff += thread_sum_diff;
}
}
__syncthreads();
}
if (threadIdx.x == 0) {
atomicAdd(sum_diff, block_sum_diff);
for (;;) {
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
double prev_diff = reinterpret_cast<double const&>(prev);
double new_max_diff = fmax(block_max_diff, prev_diff);
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
if (found == prev) break;
}
}
}
template<typename Element>
void reference_abs_diff(
cutlass::DeviceAllocation<Element> const& data,
cutlass::DeviceAllocation<Element> const& data_ref,
double& max_diff, double& mean_diff
) {
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
cutlass::DeviceAllocation<double> result;
result.reset(2);
assert(data.size() == data_ref.size());
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
if (err != cudaSuccess) {
std::cerr << "Memset failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
dim3 block(256, 1, 1);
dim3 grid(1024, 1, 1);
reference_abs_diff_kernel<<<grid, block>>>(
data.get(), data_ref.get(), data.size(),
result.get(), result.get() + 1, kPrintDiff);
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
std::cerr << "Difference kernel failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
double result_host[2];
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
if (err != cudaSuccess) {
std::cerr << "Copy failed. Last CUDA error: "
<< cudaGetErrorString(err) << std::endl;
max_diff = mean_diff = 1e20;
return;
}
max_diff = result_host[0];
mean_diff = result_host[1] / static_cast<double>(data.size());
}
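Editor's sketch (not part of the commit) of how reference_abs_diff is typically used: copy a computed buffer and a reference buffer into cutlass::DeviceAllocation storage, then read back the maximum and mean absolute difference. The data values here are placeholders.

#include <cstdio>
#include <vector>
#include "cutlass/util/device_memory.h"

void compare_buffers_sketch() {
  size_t count = 1 << 20;
  std::vector<float> host(count, 1.0f);               // placeholder data
  cutlass::DeviceAllocation<float> data(count);       // "computed" output
  cutlass::DeviceAllocation<float> data_ref(count);   // reference output
  data.copy_from_host(host.data(), count);
  data_ref.copy_from_host(host.data(), count);
  double max_diff = 0.0, mean_diff = 0.0;
  reference_abs_diff(data, data_ref, max_diff, mean_diff);
  printf("max |diff| = %f, mean |diff| = %f\n", max_diff, mean_diff);
  // Setting REF_PRINT_DIFF=1 in the environment makes the kernel print individual mismatches.
}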

View File

@ -163,6 +163,7 @@ foreach(EXAMPLE
82_blackwell_distributed_gemm
83_blackwell_sparse_gemm
84_blackwell_narrow_precision_sparse_gemm
+88_hopper_fmha
)
add_subdirectory(${EXAMPLE})

View File

@ -55,3 +55,7 @@ cutlass_example_add_executable(
tiled_copy.cu
)
+cutlass_example_add_executable(
+cute_tutorial_tiled_copy_if
+tiled_copy_if.cu
+)

View File

@ -0,0 +1,297 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
// This example extends `tiled_copy` using predicate tensors to guard memory accesses performed
// by `cute::copy_if()`. This enables tensors to have shapes that are not integer multiples of
// block sizes.
//
// This is accomplished by instantiating a tensor of coordinates which correspond to tensor elements
// to be accessed and then computing a predicate tensor which masks accesses. The example demonstrates
// how constructing of an identity tensor containing coordinates and a predicate tensor containing
// mask bits can be implemented using the same CuTe operations used to tile the tensors in
// Global Memory.
//
// This example implements two variants:
// - copy_if_kernel() uses `cute::local_partition()` to construct each thread's slice
// - copy_if_kernel_vectorized() uses `make_tiled_copy()` to implement vectorized memory accesses.
//
// The tensor shapes and strides must be divisible by the shape of the vector access.
//
/// Simple copy kernel.
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class BlockShape, class ThreadLayout>
__global__ void copy_if_kernel(TensorS S, TensorD D, BlockShape block_shape, ThreadLayout)
{
using namespace cute;
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
auto shape_S = shape(S);
Tensor C = make_identity_tensor(shape_S);
// Construct a predicate tensor which compares the coordinates with the original shape
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
// Tile the input tensor into blocks
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
// Construct a partitioning of the tile among threads with the given thread arrangement.
// Concept: Tensor ThrLayout ThrIndex
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
Tensor thr_tile_P = local_partition(tile_P, ThreadLayout{}, threadIdx.x);
// Copy from GMEM to GMEM using `thr_tile_P` to guard accesses.
copy_if(thr_tile_P, thr_tile_S, thr_tile_D);
}
/// Vectorized copy kernel.
///
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class BlockShape, class Tiled_Copy>
__global__ void copy_if_kernel_vectorized(TensorS S, TensorD D, BlockShape block_shape, Tiled_Copy tiled_copy)
{
using namespace cute;
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
auto shape_S = shape(S);
Tensor C = make_identity_tensor(shape_S);
// Construct a predicate tensor which compares the coordinates with the original shape
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
// Tile the input tensor into blocks
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
//
// Construct a Tensor corresponding to each thread's slice.
//
ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CPY, CPY_M, CPY_N)
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CPY, CPY_M, CPY_N)
Tensor thr_tile_P = thr_copy.partition_S(tile_P); // (CPY, CPY_M, CPY_N)
#if 0
// Copy from GMEM to GMEM
copy_if(tiled_copy, thr_tile_P, thr_tile_S, thr_tile_D);
#else
// make_fragment_like() constructs a tensor in RMEM with the same shape as thr_tile_S.
Tensor frag = make_fragment_like(thr_tile_S);
// Copy from GMEM to RMEM and from RMEM to GMEM
copy_if(tiled_copy, thr_tile_P, thr_tile_S, frag);
copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D);
#endif
}
/// Main function
int main(int argc, char** argv)
{
//
// Given a 2D shape, perform an efficient copy
//
using namespace cute;
using Element = float;
// Define a tensor shape with dynamic extents (m, n)
auto tensor_shape = make_shape(528, 300);
thrust::host_vector<Element> h_S(size(tensor_shape));
thrust::host_vector<Element> h_D(size(tensor_shape));
//
// Initialize
//
for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
}
thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;
thrust::device_vector<Element> d_Zero = h_D;
//
// Make tensors
//
Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));
//
// Partition
//
// Define a statically sized block (M, N).
//
// Note, by convention, capital letters are used to represent static modes.
auto block_shape = make_shape(Int<128>{}, Int<64>{});
// Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
// shape, and modes (m', n') correspond to the number of tiles.
//
// These will be used to determine the CUDA kernel grid dimensions.
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
// Describes the layout of threads which is then replicated to tile 'block_shape.'
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (ThrM, ThrN)
//
// Determine grid and block dimensions
//
dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(thr_layout));
//
// Launch the kernel
//
// copy_if()
copy_if_kernel<<< gridDim, blockDim >>>(
tensor_S,
tensor_D,
block_shape,
thr_layout);
cudaError result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
return -1;
}
h_D = d_D;
//
// Verification
//
auto verify = [](thrust::host_vector<Element> const &S, thrust::host_vector<Element> const &D){
int32_t errors = 0;
int32_t const kErrorLimit = 10;
if (S.size() != D.size()) {
return 1;
}
for (size_t i = 0; i < D.size(); ++i) {
if (S[i] != D[i]) {
std::cerr << "Error. S[" << i << "]: " << S[i] << ", D[" << i << "]: " << D[i] << std::endl;
if (++errors >= kErrorLimit) {
std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl;
return errors;
}
}
}
return errors;
};
if (verify(h_D, h_S)) {
return -1;
} else {
std::cout << "Success." << std::endl;
}
thrust::copy(d_Zero.begin(), d_Zero.end(), d_D.begin());
// Construct a TiledCopy with a specific access pattern.
// This version uses a
// (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc),
// (2) Layout-of-Values that each thread will access.
// Value arrangement per thread
Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx
// Define `AccessType` which controls the size of the actual memory access instruction.
using CopyOp = UniversalCopy<uint_byte_t<sizeof(Element) * size(val_layout)>>; // A very specific access width copy instruction
//using CopyOp = UniversalCopy<cutlass::AlignedArray<Element, size(val_layout)>>; // A more generic type that supports many copy strategies
//using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs
// A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element.
using Atom = Copy_Atom<CopyOp, Element>;
// Construct tiled copy, a tiling of copy atoms.
//
// Note, this assumes the vector and thread layouts are aligned with contiguous data
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
// reads. Alternative value layouts are also possible, though incompatible layouts
// will result in compile time errors.
TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy
thr_layout, // thread layout (e.g. 32x4 Col-Major)
val_layout); // value layout (e.g. 4x1)
// copy_if() with vectorization
copy_if_kernel_vectorized<<< gridDim, blockDim >>>(
tensor_S,
tensor_D,
block_shape,
tiled_copy);
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
return -1;
}
h_D = d_D;
if (verify(h_D, h_S)) {
return -1;
} else {
std::cout << "Success." << std::endl;
}
return 0;
}

View File

@ -0,0 +1,200 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass.cute as cute
import cutlass
import torch
import numpy as np
from cutlass.cute.runtime import from_dlpack
"""
A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL.
This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL.
It shows various ways to allocate different data structures in shared memory:
1. Struct allocation with natural and strict alignment
2. Raw memory block allocation with custom alignment
3. Array allocation with automatic alignment
4. Tensor allocation with layout specification
The example includes:
- Shared storage struct with mixed alignment requirements
- Memory allocation patterns for different data types
- Tensor operations on allocated memory
To run this example:
.. code-block:: bash
python examples/ampere/smem_allocator.py
The example will allocate shared memory, perform tensor operations, and verify the results.
"""
@cute.struct
class complex:
real: cutlass.Float32
imag: cutlass.Float32
# SharedStorage size is 512, alignment is 128
@cute.struct
class SharedStorage:
# struct elements with natural alignment
a: cute.struct.MemRange[cutlass.Float32, 32] # array
b: cutlass.Int64 # scalar
c: complex # nested struct
# struct elements with strict alignment
x: cute.struct.Align[
cute.struct.MemRange[cutlass.Float32, 32],
128,
]
y: cute.struct.Align[cutlass.Int32, 8]
z: cute.struct.Align[complex, 16]
@cute.kernel
def kernel(
const_a: cutlass.Constexpr,
dst_a: cute.Tensor,
const_b: cutlass.Constexpr,
dst_b: cute.Tensor,
const_c: cutlass.Constexpr,
dst_c: cute.Tensor,
):
# Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) can be reserved for the developer to use
# Note: alignment of the initial allocator base ptr is 1024
allocator = cutlass.utils.SmemAllocator()
# base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory)
# -- Allocate a struct --
# Note: when specified alignment, max(alignment, alignof(struct)) will be applied
# reserves the section of struct in smem, elements in the struct can be accessed by ptr
struct_in_smem = allocator.allocate(SharedStorage)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct)
# -- Allocate a block of memory --
# reserves a section of 64 bytes in smem, align to 128 bytes, returns the section base ptr
section_in_smem = allocator.allocate(64, byte_alignment=128)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section)
# -- Allocate an array --
# reserves an int64 array of size 14 in smem, returns the array base ptr
array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array)
# -- Allocate a tensor --
# Note: use cute.ComposedLayout or cute.Layout to specify layout of tensor
# Note: iterator swizzle with swizzle layout is currently not supported
layout = cute.make_layout((16, 2))
tensor_in_smem = allocator.allocate_tensor(
element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None
)
# base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor)
# ptr<f32, smem, align<1024>>
# ptr<i64, smem, align<128>>
# ptr<f32, smem, align<8>>
print(struct_in_smem.a.data_ptr())
print(struct_in_smem.b)
print(struct_in_smem.c.real)
# ptr<i8, smem, align<512>>
print(section_in_smem)
# ptr<i64, smem, align<64>>
print(array_in_smem)
# tensor<ptr<f32, smem, align<32>> o (16,2):(1,16)>
print(tensor_in_smem)
# fill MemRange tensor in struct and copy to dst
a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4)))
a_tensor.fill(const_a)
cute.printf("cute.struct.MemRange: {}", a_tensor)
dst_a.store(a_tensor.load())
# convert block of smem to fill tensor and copy to dst
layout = cute.make_layout((8, 2))
sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32)
sec_tensor = cute.make_tensor(sec_ptr, layout)
sec_tensor.fill(const_b)
cute.printf("block of memory: {}", sec_tensor)
dst_b.store(sec_tensor.load())
# fill allocated tensor in smem and copy to dst
tensor_in_smem.fill(const_c)
cute.printf("tensor in smem: {}", tensor_in_smem)
dst_c.store(tensor_in_smem.load())
@cute.jit
def run_allocation_kernel(
const_a: cutlass.Constexpr,
dst_a: cute.Tensor,
const_b: cutlass.Constexpr,
dst_b: cute.Tensor,
const_c: cutlass.Constexpr,
dst_c: cute.Tensor,
):
# additional size for the example, 64(section) + 112(array) + 128(tensor) < 384
additional_bytes = 384
# Note: launch shared memory size is: SMEM_SIZE = 512 + 384 = 896 bytes
kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch(
grid=(1, 1, 1),
block=(1, 1, 1),
smem=SharedStorage.size_in_bytes() + additional_bytes,
)
def verify_allocation_kernel(const_a, const_b, const_c):
dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda")
dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda")
dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda")
run_allocation_kernel(
const_a,
from_dlpack(dst_a),
const_b,
from_dlpack(dst_b),
const_c,
from_dlpack(dst_c),
)
np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0])
np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0])
np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0])
if __name__ == "__main__":
# prepare cuda context
cutlass.cuda.initialize_cuda_context()
# An example for shared memory allocation
const_a = 0.5
const_b = 1.0
const_c = 2.0
verify_allocation_kernel(const_a, const_b, const_c)

View File

@ -0,0 +1,51 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.15)
project(tensor)
# Find Python
find_package(Python COMPONENTS Interpreter Development REQUIRED)
# Get Python site-packages directory using Python
execute_process(
COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])"
OUTPUT_VARIABLE Python_SITE_PACKAGES
OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}")
# Add nanobind path to CMAKE_PREFIX_PATH
list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake)
# Find nanobind
find_package(nanobind REQUIRED)
# Add the module
nanobind_add_module(tensor tensor.cpp)

View File

@ -0,0 +1,305 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
This example demonstrates a basic approach to building customized interfaces as C-structures between user code
and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions
and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions.
The C-structure is defined as:
.. code-block:: c
struct Tensor {
void *ptr; // Pointer to tensor data
int32_t shape[3]; // Tensor dimensions
int32_t strides[3]; // Memory strides for each dimension
};
The example defines Tensor and TensorValue classes that wrap C structs for view of a tensor with its data pointer,
shape, and strides, enabling efficient data passing between different language boundaries.
.. note::
Future development may include automated code generation flows.
"""
import cutlass
import cutlass.cute as cute
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
import cutlass._mlir.extras.types as T
class ExampleTensorValue(ir.Value):
"""A wrapper class for tensor values in MLIR.
This class extends ir.Value to provide convenient access to tensor data pointer,
shape, and strides through MLIR operations.
:type: ir.Value
"""
def __init__(self, v):
"""Initialize a new TensorValue.
:param v: The underlying MLIR value to wrap
:type v: ir.Value
"""
super().__init__(v)
@property
def data_ptr(self, *, loc=None, ip=None):
"""Get the data pointer from the tensor value.
Extracts the data pointer (first field) from the LLVM struct value.
:param loc: Optional location information for MLIR operations
:type loc: Optional[ir.Location]
:param ip: Optional insertion point for MLIR operations
:type ip: Optional[ir.InsertionPoint]
:return: An integer value representing the data pointer
:rtype: ir.Value
"""
# Extract the data pointer from the LLVM struct value
# The data pointer is the first field (index 0) in the struct
# Use llvm.extractvalue to get the pointer field from the struct
ptr_val = llvm.extractvalue(
llvm.PointerType.get(),
self,
[0], # Extract the first field (index 0)
loc=loc,
ip=ip,
)
return cute.make_ptr(cutlass.Float32, ptr_val)
@property
def shape(self):
"""Get the shape of the tensor.
Extracts the shape (second field) from the LLVM struct value.
:return: A tuple of integers representing the tensor dimensions
:rtype: tuple[ir.Value, ...]
"""
i32_type = ir.IntegerType.get_signless(32)
# Extract the shape field from the LLVM struct value
# The shape is the second field (index 1) in the struct
shape_val = llvm.extractvalue(
llvm.StructType.get_literal([i32_type] * 3),
self,
[1], # Extract the second field (index 1)
)
# Extract each dimension from the shape struct
return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3))
@property
def stride(self):
"""Get the strides of the tensor.
Extracts the strides (third field) from the LLVM struct value.
:return: A tuple of integers representing the tensor strides
:rtype: tuple[ir.Value, ...]
"""
i32_type = ir.IntegerType.get_signless(32)
# Extract the strides field from the LLVM struct value
# The strides are the third field (index 2) in the struct
strides_val = llvm.extractvalue(
llvm.StructType.get_literal([i32_type] * 3),
self,
[2], # Extract the third field (index 2)
)
# Extract each dimension from the strides struct
return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3))
class ExampleTensor:
"""A class representing a tensor with its data pointer, shape, and strides.
This class provides a Python interface to create and manipulate tensor structures
that can be passed to CUTE JIT compiled functions.
:ivar _c_struct_p: The C struct pointer for the tensor
:ivar _rank: The number of dimensions in the tensor
"""
def __init__(self, c_struct_p, rank):
"""Initialize a new Tensor.
:param c_struct_p: The C struct pointer for the tensor
:type c_struct_p: int
:param rank: The number of dimensions in the tensor
:type rank: int
"""
self._c_struct_p = c_struct_p
self._rank = rank
def __get_mlir_types__(self):
"""Get the MLIR types for this tensor.
Creates an LLVM structure type representing a C-structure with:
.. code-block:: c
struct Tensor {
void *ptr;
int32_t shape[3];
int32_t strides[3];
};
:return: A list containing the MLIR struct type
:rtype: list[llvm.StructType]
"""
# Get the number of dimensions from the shape
ndim = self._rank
# Create the pointer type (void*)
ptr_type = llvm.PointerType.get()
# Create array types for shape and strides (int32_t[ndim])
int32_type = ir.IntegerType.get_signless(32)
shape_type = llvm.StructType.get_literal([int32_type] * ndim)
strides_type = llvm.StructType.get_literal([int32_type] * ndim)
# Create the structure type
struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type])
return [struct_type]
def __new_from_mlir_values__(self, values):
"""Create a new TensorValue from MLIR values.
:param values: A list of MLIR values
:type values: list[ir.Value]
:return: A new TensorValue instance
:rtype: TensorValue
"""
return ExampleTensorValue(values[0])
def __c_pointers__(self):
"""Get the C pointers for this tensor.
:return: A list containing the C struct pointer
:rtype: list[int]
"""
return [self._c_struct_p]
@cute.jit
def foo(tensor):
"""Example JIT function that prints tensor information.
:param tensor: A Tensor instance to print information about
:type tensor: Tensor
"""
cute.printf("data_ptr: {}", tensor.data_ptr)
cute.printf("shape: {}", tensor.shape)
cute.printf("stride: {}", tensor.stride)
mA = cute.make_tensor(
tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride)
)
cute.print_tensor(mA)
import sys
import os
import subprocess
import shutil
import tempfile
import torch
def run_test(tmpdir=None):
# Skip cleanup if user provides tmpdir
cleanup = tmpdir is None
# Initialize temporary build directory
tmpdir = tmpdir or tempfile.mkdtemp()
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
subprocess.run(["cmake", "--build", tmpdir], check=True)
sys.path.append(tmpdir)
from tensor import make_tensor, pycapsule_get_pointer
# Mock test tensor and corresponding C structure for this example
# In production, this may come from external library
x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
c_struct_p = pycapsule_get_pointer(c_struct)
# Initialize tensor wrapper and compile test function
tensor = ExampleTensor(c_struct_p, len(x.shape))
compiled_func = cute.compile(foo, tensor)
# Benchmark pointer access performance
from time import time
start = time()
# Measure performance of critical path pointer access
# getting the C pointers is on the critical path when calling the JIT compiled function
for _ in range(1000):
tensor.__c_pointers__()
end = time()
print(f"__c_pointers__: {(end - start) * 1000} us")
# Execute compiled function
compiled_func(tensor)
except Exception as e:
print(e)
finally:
if cleanup:
# Clean up the temporary directory
shutil.rmtree(tmpdir)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Set temporary directory for building C modules"
)
parser.add_argument(
"--tmp-dir", type=str, help="Temporary directory path for building C modules"
)
args = parser.parse_args()
run_test(args.tmp_dir)

View File

@ -0,0 +1,82 @@
// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
#include <cstdint>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
// Forward declaration of the MockTensor struct for testing only
struct MockTensor {
void *ptr;
struct {
int32_t shape[3];
} shape;
struct {
int32_t strides[3];
} strides;
};
NB_MODULE(tensor, m) {
// create a tensor for testing
m.def("make_tensor", [](int64_t ptr, std::vector<int32_t> shape,
std::vector<int32_t> strides) {
auto *tensor = new MockTensor();
tensor->ptr = reinterpret_cast<void *>(ptr);
assert(shape.size() == 3 && "shape must have 3 elements");
assert(strides.size() == 3 && "strides must have 3 elements");
for (size_t i = 0; i < shape.size(); i++) {
tensor->shape.shape[i] = shape[i];
tensor->strides.strides[i] = strides[i];
}
return nb::steal(PyCapsule_New(tensor, "tensor", [](PyObject *capsule) {
auto n = PyCapsule_GetName(capsule);
if (void *p = PyCapsule_GetPointer(capsule, n)) {
delete reinterpret_cast<MockTensor *>(p);
}
}));
});
m.def(
"pycapsule_get_pointer",
[](nb::object &capsule) {
void *ptr = PyCapsule_GetPointer(capsule.ptr(), "tensor");
if (!ptr) {
throw std::runtime_error("Invalid tensor capsule");
}
return reinterpret_cast<uintptr_t>(ptr);
},
"Get pointer from PyCapsule");
}

File diff suppressed because it is too large

View File

@ -83,11 +83,6 @@
"\n",
" # Print hello world from host code\n",
" cute.printf(\"hello world\")\n",
-" \n",
-" # Initialize CUDA context for launching a kernel with error checking\n",
-" # We make context initialization explicit to allow users to control the context creation \n",
-" # and avoid potential issues with multiple contexts\n",
-" cutlass.cuda.initialize_cuda_context()\n",
"\n",
" # Launch kernel\n",
" kernel().launch(\n",
@ -129,6 +124,11 @@
}
],
"source": [
+"# Initialize CUDA context for launching a kernel with error checking\n",
+"# We make context initialization explicit to allow users to control the context creation \n",
+"# and avoid potential issues with multiple contexts\n",
+"cutlass.cuda.initialize_cuda_context()\n",
+"\n",
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
"print(\"Running hello_world()...\")\n",
"hello_world()\n",
@ -136,6 +136,7 @@
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
"print(\"Compiling...\")\n",
"hello_world_compiled = cute.compile(hello_world)\n",
+"\n",
"# Run the pre-compiled version\n",
"print(\"Running compiled version...\")\n",
"hello_world_compiled()"

View File

@ -33,7 +33,6 @@
#include <cute/config.hpp>
#include <cute/tensor_impl.hpp>
-#include <cute/tensor_predicate.hpp>
namespace cute
{
@ -45,7 +44,7 @@ template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout,
-class PrdTensor = TrivialPredTensor>
+class PrdTensor = constant_fn<true_type>>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
@ -64,7 +63,7 @@ template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout,
-class PrdTensor = TrivialPredTensor>
+class PrdTensor = constant_fn<true_type>>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,

View File

@ -36,7 +36,6 @@
#include <cute/swizzle.hpp> // cute::Swizzle
#include <cute/swizzle_layout.hpp> // cute::get_nonswizzle_portion
#include <cute/tensor_impl.hpp> // cute::Tensor
-#include <cute/tensor_predicate.hpp>
#include <cute/algorithm/copy.hpp>
#include <cute/atom/copy_atom.hpp>

View File

@ -32,7 +32,6 @@
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/tensor_impl.hpp> // cute::Tensor
-#include <cute/tensor_predicate.hpp> // cute::TrivialPredTensor
#include <cute/atom/copy_atom.hpp> // cute::Copy_Atom
namespace cute
@ -66,10 +65,45 @@ copy_if(PrdTensor const& pred,
// copy_if -- Predicated CopyAtom
//
// Predicate Tensor is an Actual Tensor
template <class... CopyArgs,
class PrdEngine, class PrdLayout,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor<PrdEngine, PrdLayout> const& prd, // ([V],Rest...)
Tensor<SrcEngine, SrcLayout> const& src, // ( V, Rest...)
Tensor<DstEngine, DstLayout> & dst) // ( V, Rest...)
{
if constexpr (PrdLayout::rank == SrcLayout::rank - 1) {
// Back-compat ONLY -- Delete?
copy_if(copy_atom, make_tensor(prd.data(), prepend(prd.layout(), Layout<_1,_0>{})), src, dst);
} else {
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
static_assert(SrcLayout::rank == PrdLayout::rank, "CopyAtom rank-mismatch.");
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
copy_atom.call(prd, src, dst);
} else { // Loop over all but the first mode
constexpr int R = SrcLayout::rank;
Tensor prd_v = group_modes<1,R>(prd);
Tensor src_v = group_modes<1,R>(src);
Tensor dst_v = group_modes<1,R>(dst);
CUTE_UNROLL
for (int i = 0; i < size<1>(dst_v); ++i) {
copy_atom.call(prd_v(_,i), src_v(_,i), dst_v(_,i));
}
}
}
}
template <class... CopyArgs,
class PredTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
[[deprecated("Use a bool-tensor or transform-tensor as predication.")]]
CUTE_HOST_DEVICE
void
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
@ -77,33 +111,14 @@ copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
-static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
-auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(declval<typename decltype(t)::Traits>().with(true))>{}, copy_atom);
-if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
-if constexpr (has_with_bool) {
-copy_atom.with(pred()).call(src, dst);
-} else {
-if (pred()) { copy_atom.call(src, dst); }
-}
-} else { // Loop over all but the first mode
-constexpr int R = SrcLayout::rank;
-Tensor src_v = group_modes<1,R>(src);
-Tensor dst_v = group_modes<1,R>(dst);
-CUTE_UNROLL
-for (int i = 0; i < size<1>(dst_v); ++i) {
-if constexpr (has_with_bool) {
-copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i));
-} else {
-if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); }
-}
-}
-}
+Tensor tpred = cute::lazy::transform(make_tensor(counting_iterator<int>{}, replace<0>(shape(dst), _1{})), pred);
+return copy_if(copy_atom, tpred, src, dst);
}
//
// copy_if -- AutoCopyAsync
//
template <class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
@ -159,7 +174,7 @@ copy(AutoCopyAsync const& cpy,
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
-copy_if(cpy, TrivialPredTensor{}, src, dst);
+copy_if(cpy, constant_fn<true_type>{}, src, dst);
}
//
@ -202,7 +217,7 @@ copy(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
-CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c));
+CUTE_STATIC_ASSERT_V( size<1>(src_c) == size<1>(dst_c));
CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst));
CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src));
@ -224,7 +239,7 @@ copy(Copy_Atom<CopyArgs...> const& copy_atom,
////////////////////////////////////////////////////////
// Specialization for AutoVectorizingCopyAssumedAlignment<MaxVecBits>
-template <int MaxVecBits, class... Args,
+template <int MaxVecBits,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
@ -234,23 +249,30 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
Tensor<DstEngine, DstLayout> & dst)
{
constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst));
-constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
-static_assert(is_integral<decltype(Int<common_elem>{} * sizeof_bits_v<typename SrcEngine::value_type>)>::value, "Error: Attempting a subbit copy!");
-constexpr int vec_bits = gcd(common_elem * sizeof_bits_v<typename SrcEngine::value_type>, align_bits);
-if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) {
-// If more than one element vectorizes to 8bits or more, then recast and copy
-using VecType = uint_bit_t<vec_bits>;
-// Preserve volatility
-using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
-using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
-// Recast
-Tensor src_v = recast<SrcVecType>(src);
-Tensor dst_v = recast<DstVecType>(dst);
-return copy_if(TrivialPredTensor{}, src_v, dst_v);
+static_assert(is_integral<decltype(Int<common_elem>{} * sizeof_bits_v<typename DstEngine::value_type>)>::value, "Error: Attempting a subbit write!");
+if constexpr (common_elem > 1)
+{
+constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
+constexpr int vec_bits = gcd(common_elem * sizeof_bits_v<typename SrcEngine::value_type>, align_bits);
+if constexpr ((vec_bits % 8) == 0)
+{
+// If more than one element vectorizes to 8bits or more, then recast and copy
+using VecType = uint_bit_t<vec_bits>;
+// Preserve volatility
+using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
+using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
+// Recast
+Tensor src_v = recast<SrcVecType>(src);
+Tensor dst_v = recast<DstVecType>(dst);
+return copy_if(constant_fn<true_type>{}, src_v, dst_v);
+} else {
+return copy_if(constant_fn<true_type>{}, src, dst);
+}
} else {
-return copy_if(TrivialPredTensor{}, src, dst);
+return copy_if(constant_fn<true_type>{}, src, dst);
}
}
@ -277,7 +299,7 @@ copy(AutoFilter<CopyOp> const& copy_op,
Tensor src_n = zipped_divide(src, dst_null);
CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error");
-CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy");
+CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous race-condition detected.");
copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_));
} else {
@ -335,6 +357,18 @@ copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, Args...> con
return copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>{}, src, dst);
}
template <int MaxVecBits, class... Args,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<Copy_Traits<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>>, Args...> const&,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
return copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>{}, src, dst);
}
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
template <class... CT_Args,
class SrcEngine, class SrcLayout,
@ -375,8 +409,8 @@ template <class... CT_Args, class... CA_Args,
CUTE_HOST_DEVICE
void
copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
return copy(static_cast<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const&>(atom), src, dst);
}
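Editor's sketch (not part of the commit): the deprecation above steers users toward bool-valued or transform-based predicate tensors. A minimal host-side use of the generic copy_if with an explicit bool tensor, with made-up sizes and data:

#include <cute/tensor.hpp>

void copy_if_bool_predicate_sketch() {
  using namespace cute;
  float src_data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float dst_data[8] = {};
  Tensor src = make_tensor(&src_data[0], make_layout(Int<8>{}));
  Tensor dst = make_tensor(&dst_data[0], make_layout(Int<8>{}));
  Tensor prd = make_tensor<bool>(make_layout(Int<8>{}));   // owning bool predicate tensor
  for (int i = 0; i < 8; ++i) { prd(i) = (i % 2 == 0); }   // keep even indices only
  copy_if(prd, src, dst);   // element-wise copy guarded by prd; odd slots of dst stay zero
}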

View File

@ -90,18 +90,19 @@ constexpr bool has_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = true;
} // end namespace detail
-template <class CopyOp, class... CT_Args, class... CA_Args,
+template <class CopyOp, class... CT_Args, class CopyType,
class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
-prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CA_Args...> const& atom,
+prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CopyType> const& atom,
Tensor<GEngine, GLayout> const& src)
{
if constexpr (detail::has_prefetch<CopyOp>) {
using Prefetch_Traits = Copy_Traits<typename CopyOp::PREFETCH, CT_Args...>;
-using Prefetch_Atom = Copy_Atom<Prefetch_Traits, CA_Args...>;
+using Prefetch_Atom = Copy_Atom<Prefetch_Traits, CopyType>;
Prefetch_Atom prefetch_atom{atom};
-auto& dst = const_cast<Tensor<GEngine, GLayout>&>(src); // dst is ignored for prefetch atoms
+//auto& dst = const_cast<Tensor<GEngine, GLayout>&>(src); // dst is ignored for prefetch atoms
+Tensor dst = make_tensor(make_smem_ptr<CopyType>(nullptr), shape(src));
return copy(prefetch_atom, src, dst);
} else {
return prefetch(src);

View File

@ -163,4 +163,16 @@ transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
return transform(tensor_in1, tensor_in2, tensor_out, op);
}
namespace lazy {
template <class Engine, class Layout, class Fn>
CUTE_HOST_DEVICE constexpr
auto
transform(cute::Tensor<Engine,Layout> const& t, Fn const& fn)
{
return cute::make_tensor(cute::make_transform_iter(fn, t.data()), t.layout());
}
} // end namespace lazy
} // end namespace cute
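Editor's sketch (not part of the commit): cute::lazy::transform builds a tensor whose elements are computed on access, which is how the tiled_copy_if tutorial above derives its predicate tensor from an identity (coordinate) tensor. The shapes and the bound here are made up.

#include <cstdio>
#include <cute/tensor.hpp>

void lazy_transform_sketch() {
  using namespace cute;
  auto shape = make_shape(5, 3);
  auto bound = make_shape(4, 2);                 // pretend the valid region is smaller
  Tensor coord = make_identity_tensor(shape);    // elements are (i,j) coordinates
  Tensor pred  = lazy::transform(coord, [&](auto c) { return elem_less(c, bound); });
  for (int i = 0; i < size(pred); ++i) {         // predicate values are computed lazily here
    printf("%d ", int(pred(i)));
  }
  printf("\n");
}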

View File

@ -0,0 +1,107 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include <cute/config.hpp>
#include <cute/tensor_impl.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/fill.hpp>
namespace cute
{
// Reduce @src tensor using binary reduction operator @op and initial value @init and return a scalar.
template <class SrcEngine, class SrcLayout, class T, class BinaryOp = cute::plus>
CUTE_HOST_DEVICE constexpr
T
reduce(Tensor<SrcEngine,SrcLayout> const& src, T init, BinaryOp op = {})
{
for (auto i = 0; i < size(src); ++i) {
init = op(init, src(i));
}
return init;
}
// Reduce the RedMode of @src tensor using binary reduction operator @op and store the result in @dst tensor
// for each index of the BatchMode of @dst.
// @pre @src tensor has rank 2
// @pre size of @src batch mode is equal to size of @dst batch mode
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout,
class BinaryOp = cute::plus>
CUTE_HOST_DEVICE constexpr
void
batch_reduce(Tensor<SrcEngine, SrcLayout> const& src, // (RedMode, BatchMode)
Tensor<DstEngine, DstLayout> & dst, // (BatchMode)
BinaryOp op = {})
{
// Precondition
CUTE_STATIC_ASSERT_V(rank(src) == Int<2>{});
assert(size<1>(src) == size(dst));
for (int i = 0; i < size(dst); ++i) {
dst(i) = reduce(src(_,i), dst(i), op);
}
}
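// Editorial usage sketch (not part of this commit): mode-0 is the reduction mode and mode-1
// the batch mode, so a (4,3)-shaped source reduces into a size-3 destination.
CUTE_HOST_DEVICE void batch_reduce_usage_sketch()
{
  Tensor src = make_tensor<float>(make_layout(make_shape(Int<4>{}, Int<3>{})));
  Tensor dst = make_tensor<float>(make_layout(Int<3>{}));
  fill(src, 1.0f);
  fill(dst, 0.0f);            // dst also provides the per-batch initial value
  batch_reduce(src, dst);     // dst(i) == 4.0f for every i
}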
// Reduce @src tensor along selected modes specified in @target_profile using binary reduction operator @op
// and store the result in @dst tensor. @target_profile is a tuple where '_' indicates modes to keep and
// integers indicate modes to reduce.
// @pre @target_profile is compatible with @src layout
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout,
class TargetProfile,
class BinaryOp = cute::plus>
CUTE_HOST_DEVICE constexpr
void
logical_reduce(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
TargetProfile const& target_profile,
BinaryOp op = {})
{
// Precondition
assert(compatible(target_profile, shape(src)));
auto diced_layout = dice(target_profile, src.layout());
auto sliced_layout = slice(target_profile, src.layout());
auto red_mode = conditional_return<rank(diced_layout) == Int<0>{}>(Layout<_1,_0>{}, diced_layout);
auto batch_mode = conditional_return<rank(sliced_layout) == Int<0>{}>(Layout<_1,_0>{}, sliced_layout);
auto src_tensor = make_tensor(src.data(), make_layout(red_mode, batch_mode));
batch_reduce(src_tensor, dst, op);
}
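// Editorial note: dice() keeps the integer-marked (reduced) modes and slice() keeps the
// '_'-marked (batch) modes, so e.g. a (4,3)-shaped @src with @target_profile (0, _) folds the
// size-4 mode and writes one result per element of the size-3 batch mode, i.e. @dst has size 3.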
} // end namespace cute

View File

@ -123,6 +123,56 @@ struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
{ {
return call(src, dst); return call(src, dst);
} }
// Check and call instruction, or recurse
template <class PEngine, class PLayout,
class SEngine, class SLayout,
class DEngine, class DLayout>
CUTE_HOST_DEVICE
void
call(Tensor<PEngine,PLayout> const& prd,
Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> & dst) const
{
static_assert(PLayout::rank == 1, "Expected rank-1 prd tensor");
static_assert(SLayout::rank == 1, "Expected rank-1 src tensor");
static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor");
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value ||
is_constant<NumValDst, decltype(size(dst))>::value) {
// Dispatch to unpack to execute instruction
Traits const& traits = static_cast<Traits const&>(*this);
auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(t.with(true))>{}, traits);
if constexpr (has_with_bool) {
copy_unpack(traits.with(prd(Int<0>{})), src, dst);
} else {
if (prd(Int<0>{})) { copy_unpack(traits, src, dst); }
}
} else if constexpr (is_tuple<decltype(shape(prd))>::value &&
is_tuple<decltype(shape(src))>::value &&
is_tuple<decltype(shape(dst))>::value) {
// If the size of the src/dst doesn't match the instruction,
// recurse on this rank-1 layout by peeling off the outer mode
// ((A,B,C,...)) -> (A,B,C,...)
return copy_if(*this, tensor<0>(prd), tensor<0>(src), tensor<0>(dst));
} else {
static_assert(dependent_false<SEngine>,
"CopyAtom: Src/Dst partitioning does not match the instruction requirement.");
}
}
// Accept mutable temporaries
template <class PEngine, class PLayout,
class SEngine, class SLayout,
class DEngine, class DLayout>
CUTE_HOST_DEVICE
void
call(Tensor<PEngine,PLayout> const& prd,
Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> && dst) const
{
return call(prd, src, dst);
}
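// Editorial note: this predicated overload backs copy_if(atom, prd, src, dst). When the
// partitioned src/dst sizes match the atom's value counts, the first predicate element either
// gates the instruction or is forwarded via traits.with() for traits that accept a predicate;
// otherwise one mode is peeled off the rank-1 layouts and the dispatch recurses.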
}; };
// //
@ -733,13 +783,13 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and
#include <cute/atom/copy_traits_sm75.hpp> #include <cute/atom/copy_traits_sm75.hpp>
#include <cute/atom/copy_traits_sm80.hpp> #include <cute/atom/copy_traits_sm80.hpp>
#include <cute/atom/copy_traits_sm90.hpp> #include <cute/atom/copy_traits_sm90.hpp>
#include <cute/atom/copy_traits_sm100.hpp> #include <cute/atom/copy_traits_sm100.hpp>
// Config // Config
#if (__CUDACC_VER_MAJOR__ >= 12) #if (__CUDACC_VER_MAJOR__ >= 12)
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED # define CUTE_COPY_ATOM_TMA_SM90_ENABLED
# define CUTE_COPY_ATOM_TMA_SM100_ENABLED # define CUTE_COPY_ATOM_TMA_SM100_ENABLED
#endif #endif

View File

@ -235,6 +235,63 @@ raw_pointer_cast(counting_iterator<T> const& x) {
return x.n_; return x.n_;
} }
//
// transform_iterator
//
template <class Fn, class Iter>
struct transform_iter
{
using iterator = Iter;
// using reference = typename iterator_traits<iterator>::reference;
// using element_type = typename iterator_traits<iterator>::element_type;
// using value_type = typename iterator_traits<iterator>::value_type;
Fn fn_;
iterator ptr_;
CUTE_HOST_DEVICE constexpr
transform_iter(Fn fn, iterator ptr = {}) : fn_(fn), ptr_(ptr) {}
CUTE_HOST_DEVICE constexpr
decltype(auto) operator*() const { return fn_(*ptr_); }
template <class Index>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator[](Index const& i) const { return fn_(ptr_[i]); }
template <class Index>
CUTE_HOST_DEVICE constexpr
auto operator+(Index const& i) const { return transform_iter<Fn, decltype(ptr_+i)>{fn_, ptr_+i}; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator==(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ == y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator!=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ != y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator< (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ < y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator<=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ <= y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator> (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ > y.ptr_; }
template <class IterY>
CUTE_HOST_DEVICE constexpr
friend bool operator>=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ >= y.ptr_; }
};
template <class Fn, class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_transform_iter(Fn const& fn, Iterator const& ptr)
{
return transform_iter<Fn,Iterator>(fn,ptr);
}
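// Editorial usage sketch (not part of this commit; names are illustrative): dereferencing a
// transform_iter applies fn on the fly, and offsetting preserves the transform.
CUTE_HOST_DEVICE void transform_iter_sketch()
{
  float buf[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  auto it = make_transform_iter([] (float x) { return x + 1.0f; }, &buf[0]);
  float v  = it[2];     // 4.0f, computed on access; buf itself is never modified
  auto  jt = it + 1;    // transform_iter over buf + 1 with the same fn
  (void) v; (void) jt;
}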
// //
// Display utilities // Display utilities
// //
@ -251,12 +308,24 @@ CUTE_HOST_DEVICE void print(counting_iterator<T> ptr)
printf("counting_iter("); print(ptr.n_); printf(")"); printf("counting_iter("); print(ptr.n_); printf(")");
} }
template <class Fn, class Iterator>
CUTE_HOST_DEVICE void print(transform_iter<Fn,Iterator> ptr)
{
printf("trans_"); print(ptr.ptr_);
}
#if !defined(__CUDACC_RTC__) #if !defined(__CUDACC_RTC__)
template <class T> template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr) CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr)
{ {
return os << "counting_iter(" << ptr.n_ << ")"; return os << "counting_iter(" << ptr.n_ << ")";
} }
template <class Fn, class Iterator>
CUTE_HOST std::ostream& operator<<(std::ostream& os, transform_iter<Fn,Iterator> ptr)
{
return os << "trans_" << ptr.ptr_;
}
#endif // !defined(__CUDACC_RTC__) #endif // !defined(__CUDACC_RTC__)
} // end namespace cute } // end namespace cute

View File

@ -41,7 +41,8 @@
#include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/dispatch_policy.hpp"
#ifndef CUTLASS_GDC_ENABLED #ifndef CUTLASS_GDC_ENABLED
#if (defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \ #if (CUDA_BARRIER_ENABLED && \
defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \
__CUDACC_VER_MAJOR__ >= 12 && \ __CUDACC_VER_MAJOR__ >= 12 && \
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL))
#define CUTLASS_GDC_ENABLED #define CUTLASS_GDC_ENABLED

View File

@ -43,7 +43,6 @@
#include "cute/arch/cluster_sm90.hpp" #include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp" #include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp" #include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp" #include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/trace.h" #include "cutlass/trace.h"

View File

@ -32,7 +32,6 @@
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cute/tensor_predicate.hpp"
#include "cute/arch/cluster_sm90.hpp" #include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp" #include "cute/arch/copy_sm90.hpp"
#include "cute/atom/mma_atom.hpp" #include "cute/atom/mma_atom.hpp"
@ -103,7 +102,7 @@ struct CollectiveConv<
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>; using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>; using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;
static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
@ -332,7 +331,7 @@ public:
TmaTransactionBytes TmaTransactionBytes
}; };
} }
template <class ProblemShape> template <class ProblemShape>
static bool static bool
can_implement( can_implement(
@ -409,7 +408,7 @@ public:
if constexpr (ConvOp == conv::Operator::kWgrad) { if constexpr (ConvOp == conv::Operator::kWgrad) {
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
std::ostringstream os; std::ostringstream os;
#endif #endif
const auto & input_shape = problem_shape.shape_A; const auto & input_shape = problem_shape.shape_A;
const auto & input_stride = problem_shape.stride_A; const auto & input_stride = problem_shape.stride_A;
@ -431,11 +430,11 @@ public:
<< "\n input_shape: " << input_shape << "\n input_shape: " << input_shape
<< "\n input_stride: " << input_stride << "\n input_stride: " << input_stride
<< "\n"; << "\n";
#endif #endif
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n");
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
CUTLASS_TRACE_HOST(os.str()); CUTLASS_TRACE_HOST(os.str());
#endif #endif
return false; return false;
} }
@ -464,7 +463,7 @@ public:
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
CUTLASS_TRACE_HOST(os.str()); CUTLASS_TRACE_HOST(os.str());
#endif #endif
return false; return false;
} }
} }
@ -516,8 +515,8 @@ public:
/// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k) /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k)
/// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k) /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k)
/// The rest of the tensors can be specified as needed by this collective. /// The rest of the tensors can be specified as needed by this collective.
/// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with
/// StrideA and StrideB set up for TMA /// StrideA and StrideB set up for TMA
template <class ProblemShapeMNKL> template <class ProblemShapeMNKL>
CUTLASS_DEVICE auto CUTLASS_DEVICE auto
load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){ load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){

View File

@ -303,6 +303,16 @@ public:
dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
// Dynamic cluster support
[[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0};
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
ConvKernel::ArchTag::kMinComputeCapability == 101) {
if constexpr (!cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape>) {
fallback_cluster = params.hw_info.cluster_shape_fallback;
cluster = params.hw_info.cluster_shape;
}
}
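// Editorial note: with a dynamic ClusterShape on SM100/SM101 the preferred and fallback
// cluster dimensions are read from params.hw_info at launch time and forwarded to
// ClusterLauncher::launch_with_fallback_cluster (or the CUDA host adapter) below.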
void* kernel_params[] = {&params}; void* kernel_params[] = {&params};
if constexpr (kEnableCudaHostAdapter) { if constexpr (kEnableCudaHostAdapter) {
// //
@ -313,6 +323,7 @@ public:
launch_result = cuda_adapter->launch(grid, launch_result = cuda_adapter->launch(grid,
cluster, cluster,
fallback_cluster,
block, block,
smem_size, smem_size,
stream, stream,
@ -338,6 +349,20 @@ public:
grid, cluster, block, smem_size, stream, kernel, kernel_params); grid, cluster, block, smem_size, stream, kernel, kernel_params);
} }
} }
else {
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
ConvKernel::ArchTag::kMinComputeCapability == 101) {
launch_result = ClusterLauncher::launch_with_fallback_cluster(
grid,
cluster,
fallback_cluster,
block,
smem_size,
stream,
kernel,
kernel_params);
}
}
} }
} }
else { else {

View File

@ -48,7 +48,7 @@ namespace cutlass::detail{
using namespace cute; using namespace cute;
template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN> template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
struct Sm100BlockwiseScaleConfig { struct Sm1xxBlockwiseScaleConfig {
using ShapeSFA = Shape<Shape<Int<SFVecSizeM>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>; using ShapeSFA = Shape<Shape<Int<SFVecSizeM>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
using ShapeSFB = Shape<Shape<Int<SFVecSizeN>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>; using ShapeSFB = Shape<Shape<Int<SFVecSizeN>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
@ -271,7 +271,18 @@ struct RuntimeBlockwiseScaleConfig {
// Sm90 only supports MN major for SFA and SFB for now // Sm90 only supports MN major for SFA and SFB for now
template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK> template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK>
using Sm90BlockwiseScaleConfig = Sm100BlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK>; using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK>;
template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
using Sm100BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK, majorSFA, majorSFB>;
template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
using Sm120BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK, majorSFA, majorSFB>;
template<class MmaTileShape_MNK>
constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) {
return Sm90BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{};
}
template<class MmaTileShape_MNK> template<class MmaTileShape_MNK>
constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) {
@ -279,8 +290,8 @@ constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) {
} }
template<class MmaTileShape_MNK> template<class MmaTileShape_MNK>
constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) { constexpr auto sm120_trivial_blockwise_scale_config(MmaTileShape_MNK) {
return Sm90BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; return Sm120BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{};
} }
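// Editorial note: the *_trivial_blockwise_scale_config helpers take the scale-factor vector
// sizes directly from the MMA tile, e.g. sm120_trivial_blockwise_scale_config(Shape<_128,_128,_64>{})
// yields Sm120BlockwiseScaleConfig<128, 128, 64> with the default MN-major layouts for SFA/SFB.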
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -371,11 +371,14 @@ template <
constexpr int constexpr int
get_input_alignment_bits() { get_input_alignment_bits() {
if constexpr (IsF8F6F4SubBytes && sizeof_bits<ElementType>::value == 4) { if constexpr (IsF8F6F4SubBytes && sizeof_bits<ElementType>::value == 4) {
// 16U4 format: The inner tensor size dimension must be a multiple of 64B.
return 64 * 8; return 64 * 8;
} }
else if constexpr (IsF8F6F4SubBytes && sizeof_bits<ElementType>::value == 6) { else if constexpr (IsF8F6F4SubBytes && sizeof_bits<ElementType>::value == 6) {
// 16U6 format : The inner tensor size dimension must be a multiple of 96B.
return 96 * 8; return 96 * 8;
} }
// TMA 16B alignment requirement
return 128; return 128;
} }
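// Editorial note: each case works out to a 128-element alignment: 64B / 4 bits = 128 elements,
// 96B / 6 bits = 128 elements, and the default 128 bits is the usual 16B TMA granularity.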
@ -383,12 +386,11 @@ get_input_alignment_bits() {
template <class ElementType> template <class ElementType>
constexpr int constexpr int
get_output_alignment_bits() { get_output_alignment_bits() {
if constexpr (sizeof_bits<ElementType>::value == 6) { if constexpr (sizeof_bits<ElementType>::value == 6) {
// U6 format : The inner tensor size dimension must be a multiple of 96B. // 16U6 format : The inner tensor size dimension must be a multiple of 96B.
return 96 * 8; return 96 * 8;
} }
// TMA 16B alignment requirement
return 128; return 128;
} }

View File

@ -981,6 +981,9 @@ private:
static constexpr bool Is2SmMma = is_base_of_v<TmaWarpSpecialized2Sm, Schedule>; static constexpr bool Is2SmMma = is_base_of_v<TmaWarpSpecialized2Sm, Schedule>;
static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
// C/D should meet TMA alignment requirement if not void
static_assert(detail::is_aligned<ElementC_, AlignmentC, ElementD_, AlignmentD>(),
"C/D Should meet TMA alignment requirement\n");
static constexpr bool DisableDestination = cute::is_void_v<ElementD_>; static constexpr bool DisableDestination = cute::is_void_v<ElementD_>;
using ElementD = cute::conditional_t<DisableDestination,fusion::get_element_aux_t<FusionOpOrCallbacks>,ElementD_>; // prevents void ref breakages using ElementD = cute::conditional_t<DisableDestination,fusion::get_element_aux_t<FusionOpOrCallbacks>,ElementD_>; // prevents void ref breakages

View File

@ -293,6 +293,9 @@ template <
class DispatchPolicy class DispatchPolicy
> >
struct Sm90TmaBuilderImpl { struct Sm90TmaBuilderImpl {
// C/D should meet TMA alignment requirement if not void
static_assert(detail::is_aligned<ElementC_, AlignmentC, ElementD_, AlignmentD>(),
"C/D Should meet TMA alignment requirement\n");
// Passing void D disables destination store + smem allocation // Passing void D disables destination store + smem allocation
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>, using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
fusion::get_element_aux_t<FusionOpOrCallbacks>, ElementD_>; fusion::get_element_aux_t<FusionOpOrCallbacks>, ElementD_>;

View File

@ -91,6 +91,14 @@ sm90_get_smem_load_op_for_source() {
} }
} }
// C/D should meet TMA alignment requirement if not void
template <class ElementC, int AlignmentC, class ElementD, int AlignmentD>
constexpr bool
is_aligned() {
return (cute::is_void_v<ElementC> || (cute::sizeof_bits_v<ElementC> * AlignmentC) % cutlass::detail::get_output_alignment_bits<ElementC>() == 0) &&
(cute::is_void_v<ElementD> || (cute::sizeof_bits_v<ElementD> * AlignmentD) % cutlass::detail::get_output_alignment_bits<ElementD>() == 0);
}
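// Editorial note: e.g. ElementC = half_t with AlignmentC = 8 gives 16 * 8 = 128 bits, a multiple
// of the 128-bit (16B) TMA requirement, so the check passes; a void C or D is accepted
// unconditionally since no TMA transfer is issued for it.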
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::epilogue::collective::detail } // namespace cutlass::epilogue::collective::detail

View File

@ -217,6 +217,16 @@ struct IsThreadEpilogueOpWithActivation <ThreadEpilogueOp, cute::enable_if_t<Thr
using type = typename ThreadEpilogueOp::ActivationFn; using type = typename ThreadEpilogueOp::ActivationFn;
}; };
template <typename ThreadEpilogueOp, typename = void>
struct IsThreadEpilogueOpWithPerChannelScaled {
static constexpr bool value = false;
};
template <typename ThreadEpilogueOp>
struct IsThreadEpilogueOpWithPerChannelScaled <ThreadEpilogueOp, cute::void_t<decltype(ThreadEpilogueOp::IsPerRowScaleSupported)>> {
static constexpr bool value = ThreadEpilogueOp::IsPerRowScaleSupported || ThreadEpilogueOp::IsPerColScaleSupported;
};
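// Editorial note: the void_t specialization is selected only when ThreadEpilogueOp exposes
// IsPerRowScaleSupported; such ops then report per-channel scaling if either the row or column
// scale flag is set, while all other epilogue ops fall back to the primary template (false).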
template <typename ThreadEpilogueOp, typename = void> template <typename ThreadEpilogueOp, typename = void>
struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {}; struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {};

View File

@ -57,7 +57,7 @@ struct IsDefaultFusionOp {
}; };
template< template<
class ElementD, class ElementCompute, class ElementD, class ElementCompute,
class ElementC, FloatRoundStyle RoundStyle class ElementC, FloatRoundStyle RoundStyle
> >
struct IsDefaultFusionOp< struct IsDefaultFusionOp<
@ -69,7 +69,7 @@ struct IsDefaultFusionOp<
template< template<
class ElementOutput, int Count, class ElementAccumulator, class ElementOutput, int Count, class ElementAccumulator,
class ElementCompute, epilogue::thread::ScaleType::Kind Scale, class ElementCompute, epilogue::thread::ScaleType::Kind Scale,
FloatRoundStyle Round, class ElementSource FloatRoundStyle Round, class ElementSource
> >
struct IsDefaultFusionOp< struct IsDefaultFusionOp<
@ -133,7 +133,7 @@ public:
constexpr static int ThreadCount = 128; constexpr static int ThreadCount = 128;
constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount;
constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value; constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type; using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
constexpr static uint32_t TmaTransactionBytes = 0; constexpr static uint32_t TmaTransactionBytes = 0;
@ -240,7 +240,7 @@ public:
Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tTR_gD)); // (T2R,T2R_M,T2R_N)
Tensor tTR_rC = make_tensor<GmemElementC>(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) Tensor tTR_rC = make_tensor<GmemElementC>(shape(tTR_gC)); // (T2R,T2R_M,T2R_N)
Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
@ -250,7 +250,7 @@ public:
Tensor tTR_rD_frag = make_tensor<ElementD>(shape(tTR_rAcc)); Tensor tTR_rD_frag = make_tensor<ElementD>(shape(tTR_rAcc));
Tensor tTR_rD_src = recast<Array<ElementD, VD>>(coalesce(tTR_rD_frag)); Tensor tTR_rD_src = recast<Array<ElementD, VD>>(coalesce(tTR_rD_frag));
Tensor tR2G_rD_dst = recast<Array<ElementD, VD>>(coalesce(tTR_gD)); Tensor tR2G_rD_dst = recast<Array<ElementD, VD>>(coalesce(tTR_gD));
Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int<VD>{}))); Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int<VD>{})));
Tensor tDpD = make_tensor<bool>(shape(tR2G_rD_dst)); Tensor tDpD = make_tensor<bool>(shape(tR2G_rD_dst));
@ -325,7 +325,7 @@ public:
copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); copy_if(tDpD, tTR_rD_src, tR2G_rD_dst);
} }
// source is not needed, avoid load // source is not needed, avoid load
else else
{ {
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i++) { for (int i = 0; i < size(tTR_rAcc); i++) {
@ -382,7 +382,7 @@ public:
auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r));
Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N)
Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N)
Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
@ -498,7 +498,7 @@ public:
// Constructor and Data Members // Constructor and Data Members
// //
CUTLASS_DEVICE CUTLASS_DEVICE
CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors)
: fusion_callbacks(params_.thread, shared_tensors.thread) : fusion_callbacks(params_.thread, shared_tensors.thread)
, smem_buffer_ptr(shared_tensors.buffer.data()) , smem_buffer_ptr(shared_tensors.buffer.data())
, params(params_) {}; , params(params_) {};
@ -506,7 +506,7 @@ public:
protected: protected:
FusionCallbacks fusion_callbacks; FusionCallbacks fusion_callbacks;
uint8_t* smem_buffer_ptr; uint8_t* smem_buffer_ptr;
Params const& params; Params const& params;
public: public:
@ -543,7 +543,7 @@ public:
can_implement( can_implement(
[[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) { [[maybe_unused]] Arguments const& args) {
bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread);
if (!fusion_implementable) { if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
@ -636,7 +636,7 @@ public:
Tensor tTR_rD_frg = recast<Array<ElementD, FragmentSize>>(coalesce(tTR_rD)); Tensor tTR_rD_frg = recast<Array<ElementD, FragmentSize>>(coalesce(tTR_rD));
auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
problem_shape_mnkl, problem_shape_mnkl,
cta_tile_mnk, cta_tile_mnk,
cta_coord_mnkl, cta_coord_mnkl,
int(0), int(0),
@ -693,20 +693,17 @@ public:
} }
Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n);
Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); });
cst_callbacks.begin_loop(epi_m, epi_n); cst_callbacks.begin_loop(epi_m, epi_n);
if constexpr (not cute::is_void_v<ElementC>) { if constexpr (not cute::is_void_v<ElementC>) {
if (is_C_load_needed) { if (is_C_load_needed) {
using CVecType = uint_bit_t<VC * sizeof_bits_v<ElementC>>; using CVecType = uint_bit_t<VC * sizeof_bits_v<ElementC>>;
Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int<VC>{})));
auto pred_fn_C = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
return elem_less(tTR_cC_frag(coords...), problem_shape_mnl);
};
Tensor tTR_gC_frg = recast<CVecType>(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); Tensor tTR_gC_frg = recast<CVecType>(coalesce(tTR_gC(_,_,_,epi_m,epi_n)));
Tensor tTR_rC_frg = recast<CVecType>(coalesce(tCrC)); Tensor tTR_rC_frg = recast<CVecType>(coalesce(tCrC));
copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg); Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int<VC>{})));
copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg);
} }
} }
@ -717,7 +714,7 @@ public:
Tensor tTR_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc)); Tensor tTR_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc);
// After the last tmem load, signal that tmem buffer is consumed and empty // After the last tmem load, signal that tmem buffer is consumed and empty
if (do_acc_release) { if (do_acc_release) {
cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::fence_view_async_tmem_load();
@ -737,16 +734,11 @@ public:
cst_callbacks.end_loop(epi_m, epi_n); cst_callbacks.end_loop(epi_m, epi_n);
Tensor tTR_cD_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclD.compose(Int<VD>{})));
auto pred_fn_D = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
return elem_less(tTR_cD_frag(coords...), problem_shape_mnl);
};
using VecType = uint_bit_t<VD * sizeof_bits_v<ElementD>>; using VecType = uint_bit_t<VD * sizeof_bits_v<ElementD>>;
Tensor tTR_gD_frg = recast<VecType>(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); Tensor tTR_gD_frg = recast<VecType>(coalesce(tTR_gD(_,_,_,epi_m,epi_n)));
Tensor tTR_rD_frg = recast<VecType>(coalesce(tTR_rD)); Tensor tTR_rD_frg = recast<VecType>(coalesce(tTR_rD));
copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg); Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int<VD>{})));
copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg);
} // for epi_m } // for epi_m
} // for epi_n } // for epi_n

View File

@ -59,7 +59,7 @@ template <
> >
class Epilogue { class Epilogue {
static_assert(cute::is_same_v<EpilogueScheduleType, EpilogueSimtVectorized> || static_assert(cute::is_same_v<EpilogueScheduleType, EpilogueSimtVectorized> ||
cute::is_same_v<EpilogueScheduleType, EpiloguePtrArraySimtVectorized>, cute::is_same_v<EpilogueScheduleType, EpiloguePtrArraySimtVectorized>,
"Could not find an epilogue specialization."); "Could not find an epilogue specialization.");
}; };
@ -141,7 +141,7 @@ public:
ElementScalar const* beta_ptr = nullptr; ElementScalar const* beta_ptr = nullptr;
ElementBias const* bias_ptr = nullptr; ElementBias const* bias_ptr = nullptr;
StrideBias dBias{}; StrideBias dBias{};
}; };
template<class ThreadEpiOp> template<class ThreadEpiOp>
struct ThreadEpilogueOpArguments< struct ThreadEpilogueOpArguments<
@ -202,7 +202,7 @@ public:
to_underlying_arguments( to_underlying_arguments(
[[maybe_unused]] ProblemShape const& _, [[maybe_unused]] ProblemShape const& _,
Arguments const& args, Arguments const& args,
[[maybe_unused]] void* workspace) { [[maybe_unused]] void* workspace) {
typename ThreadEpilogueOp::Params thread_op_args; typename ThreadEpilogueOp::Params thread_op_args;
thread_op_args.alpha = args.thread.alpha; thread_op_args.alpha = args.thread.alpha;
thread_op_args.beta = args.thread.beta; thread_op_args.beta = args.thread.beta;
@ -317,7 +317,7 @@ public:
Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
// Construct a tensor in SMEM that we can partition for rearranging data // Construct a tensor in SMEM that we can partition for rearranging data
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf); SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N)
@ -389,10 +389,10 @@ public:
Tensor tSR_gBias_flt = filter_zeros(tSR_gBias); Tensor tSR_gBias_flt = filter_zeros(tSR_gBias);
Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); Tensor tSR_rBias_flt = filter_zeros(tSR_rBias);
Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride()); Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride());
Tensor tSR_pD_flt = cute::lazy::transform(tSR_cD_flt, [&](auto const& c){ return elem_less(c, take<0,2>(residue_mnk)); });
// Step 0. Copy Bias from GMEM to fragment // Step 0. Copy Bias from GMEM to fragment
auto pred_fn = [&] (auto const&... coords) { return elem_less(tSR_cD_flt(coords...), take<0, 2>(residue_mnk)); }; copy_if(tSR_pD_flt, tSR_gBias_flt, tSR_rBias_flt);
copy_if(pred_fn, tSR_gBias_flt, tSR_rBias_flt);
} }
} }

View File

@ -560,18 +560,18 @@ struct Sm90TreeVisitor<
Tensor tC_rAux_vec = recast<VecType>(tC_rAux); Tensor tC_rAux_vec = recast<VecType>(tC_rAux);
Tensor tC_gAux_vec = recast<VecType>(tC_gAux); Tensor tC_gAux_vec = recast<VecType>(tC_gAux);
Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int<V>{}))); Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int<V>{})));
auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
copy_if(predicate_fn, tC_rAux_vec, tC_gAux_vec); copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec);
} }
// sub-byte vectorization, must serialize threads // sub-byte vectorization, must serialize threads
else { else {
// Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this)
int lane_idx = canonical_lane_idx(); int lane_idx = canonical_lane_idx();
auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (int i = 0; i < NumThreadsPerWarp; ++i) { for (int i = 0; i < NumThreadsPerWarp; ++i) {
if (lane_idx == i) { if (lane_idx == i) {
copy_if(predicate_fn, tC_rAux, tC_gAux); copy_if(tC_pAux, tC_rAux, tC_gAux);
} }
__syncwarp(); __syncwarp();
} }
@ -719,12 +719,12 @@ struct Sm90AuxLoad<
Tensor tC_gAux_vec = recast<VecType>(tC_gAux); Tensor tC_gAux_vec = recast<VecType>(tC_gAux);
Tensor tC_rAux_vec = recast<VecType>(tC_rAux); Tensor tC_rAux_vec = recast<VecType>(tC_rAux);
Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int<V>{}))); Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int<V>{})));
auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
copy_if(predicate_fn, tC_gAux_vec, tC_rAux_vec); copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec);
} }
else { else {
auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
copy_if(predicate_fn, tC_gAux, tC_rAux); copy_if(tC_pAux, tC_gAux, tC_rAux);
} }
} }
} }
@ -738,8 +738,8 @@ struct Sm90AuxLoad<
} }
} }
auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_tC_cAux); }; Tensor tC_pAux = cute::lazy::transform(tC_cAux(_,_,_,epi_m,epi_n), [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
copy_if(predicate_fn, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); copy_if(tC_pAux, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux);
} }
} }

View File

@ -449,7 +449,7 @@ template <
bool EnableNullptr bool EnableNullptr
> >
struct Sm90AuxLoad< struct Sm90AuxLoad<
0, EpilogueTile, Element, LayoutOrStrideMNL, 0, EpilogueTile, Element, LayoutOrStrideMNL,
SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr
> { > {
using ElementAux = Element; using ElementAux = Element;
@ -496,7 +496,7 @@ struct Sm90AuxLoad<
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { } : params_ptr(&params) { }
Params const* params_ptr; Params const* params_ptr;
CUTLASS_DEVICE bool CUTLASS_DEVICE bool
@ -533,7 +533,7 @@ struct Sm90AuxLoad<
tC_cAux(cute::forward<CTensorG2R>(tC_cAux)), tC_cAux(cute::forward<CTensorG2R>(tC_cAux)),
problem_shape_mnl(problem_shape_mnl), problem_shape_mnl(problem_shape_mnl),
params_ptr(params_ptr) {} params_ptr(params_ptr) {}
GTensorG2R tC_gAux; GTensorG2R tC_gAux;
RTensor tC_rAux; RTensor tC_rAux;
CTensorG2R tC_cAux; CTensorG2R tC_cAux;
@ -551,17 +551,13 @@ struct Sm90AuxLoad<
constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){};
constexpr int V = cute::min(Alignment, size(MCL)); constexpr int V = cute::min(Alignment, size(MCL));
Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n);
Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int<V>{})));
Tensor tC_gAux_vec = recast<Array<Element, V>>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); Tensor tC_gAux_vec = recast<Array<Element, V>>(coalesce(tC_gAux(_,_,_,epi_m,epi_n)));
Tensor tC_rAux_vec = recast<Array<Element, V>>(coalesce(tC_rAux)); Tensor tC_rAux_vec = recast<Array<Element, V>>(coalesce(tC_rAux));
auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int<V>{})));
return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); });
};
copy_if(pred_fn, tC_gAux_vec, tC_rAux_vec); copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec);
} }
template <typename ElementAccumulator, int FragmentSize> template <typename ElementAccumulator, int FragmentSize>
@ -647,7 +643,7 @@ struct Sm90ScalarBroadcast {
can_implement(ProblemShape const& problem_shape, Arguments const& args) { can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true; return true;
} }
template <class ProblemShape> template <class ProblemShape>
static size_t static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@ -674,11 +670,11 @@ struct Sm90ScalarBroadcast {
// This must be called after update_scalar is called // This must be called after update_scalar is called
CUTLASS_DEVICE bool CUTLASS_DEVICE bool
is_zero() const { is_zero() const {
if (get<2>(params_ptr->dScalar[0]) == 0) { if (get<2>(params_ptr->dScalar[0]) == 0) {
// Only 1 batch // Only 1 batch
return scalar == Element(0); return scalar == Element(0);
} }
else { else {
// multiple batch // multiple batch
if (valid_scalar == false) { if (valid_scalar == false) {
// for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps.
@ -761,7 +757,7 @@ private:
if (params_ptr->scalar_ptrs[0] != nullptr) { if (params_ptr->scalar_ptrs[0] != nullptr) {
scalar = params_ptr->scalar_ptrs[0][l_offset]; scalar = params_ptr->scalar_ptrs[0][l_offset];
} }
else { else {
// batch stride is ignored for nullptr fallback // batch stride is ignored for nullptr fallback
scalar = params_ptr->scalars[0]; scalar = params_ptr->scalars[0];
@ -774,7 +770,7 @@ private:
if (params_ptr->scalar_ptrs[i] != nullptr) { if (params_ptr->scalar_ptrs[i] != nullptr) {
int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]);
} }
else { else {
// batch stride is ignored for nullptr fallback // batch stride is ignored for nullptr fallback
scalar = reduction_fn(scalar, params_ptr->scalars[i]); scalar = reduction_fn(scalar, params_ptr->scalars[i]);
@ -826,7 +822,7 @@ struct Sm90ScalarBroadcastPtrArray {
can_implement(ProblemShape const& problem_shape, Arguments const& args) { can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true; return true;
} }
template <class ProblemShape> template <class ProblemShape>
static size_t static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@ -946,7 +942,7 @@ private:
if (params_ptr->scalar_ptrs[i] != nullptr) { if (params_ptr->scalar_ptrs[i] != nullptr) {
int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]);
} }
else { else {
// batch stride is ignored for nullptr fallback // batch stride is ignored for nullptr fallback
scalar = reduction_fn(scalar, params_ptr->scalars[i]); scalar = reduction_fn(scalar, params_ptr->scalars[i]);
@ -992,7 +988,7 @@ struct Sm90RowBroadcast {
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast);
struct SharedStorage { struct SharedStorage {
array_aligned<ElementInput, size<1>(CtaTileShapeMNK{})> smem; array_aligned<ElementInput, size<1>(CtaTileShapeMNK{})> smem;
}; };
@ -1078,8 +1074,8 @@ struct Sm90RowBroadcast {
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE CUTLASS_DEVICE
ConsumerStoreCallbacks( ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
Residue residue_cRow_, Params const& params_) Residue residue_cRow_, Params const& params_)
: tGS_gRow(tGS_gRow_) : tGS_gRow(tGS_gRow_)
@ -1098,8 +1094,8 @@ struct Sm90RowBroadcast {
Tiled_G2S tiled_G2S; Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Residue residue_cRow; // (m, n) Residue residue_cRow; // (m, n)
Params const& params; Params const& params;
@ -1113,7 +1109,7 @@ struct Sm90RowBroadcast {
for (int i = 0; i < size(tGS_gRow_flt); ++i) { for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM, continue; // OOB of SMEM,
} }
if (not is_nullptr && elem_less(tGS_cRow_flt(i), residue_cRow)) { if (not is_nullptr && elem_less(tGS_cRow_flt(i), residue_cRow)) {
tGS_sRow_flt(i) = tGS_gRow_flt(i); // issue async gmem to smem load tGS_sRow_flt(i) = tGS_gRow_flt(i); // issue async gmem to smem load
@ -1201,18 +1197,18 @@ struct Sm90RowBroadcast {
} }
Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L)); Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L));
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem), Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem //// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, ElementInput>{}, auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, ElementInput>{},
Layout< Shape<_1, ThreadCount>, Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{}, Stride<_0, _1>>{},
Layout<_1>{}); Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow); Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow); Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord //// G2S: Coord
Tensor tGS_cRow = thr_g2s.partition_S(args.cD); Tensor tGS_cRow = thr_g2s.partition_S(args.cD);
//// S2R: Smem to Reg //// S2R: Smem to Reg
@ -1220,11 +1216,11 @@ struct Sm90RowBroadcast {
Tensor tSR_rRow = make_tensor_like<ElementCompute>(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) Tensor tSR_rRow = make_tensor_like<ElementCompute>(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks( return ConsumerStoreCallbacks(
tGS_gRow, tGS_gRow,
tGS_sRow, tGS_sRow,
tGS_cRow, tiled_g2s, tGS_cRow, tiled_g2s,
tSR_sRow, tSR_sRow,
tSR_rRow, tSR_rRow,
args.residue_cD, args.residue_cD,
params); params);
} }
@ -1378,12 +1374,12 @@ struct Sm90ColBroadcast {
Tensor tCgCol_vec = recast<VecType>(coalesce(tCgCol_flt)); Tensor tCgCol_vec = recast<VecType>(coalesce(tCgCol_flt));
Tensor tCrCol_vec = recast<VecType>(coalesce(tCrCol_flt)); Tensor tCrCol_vec = recast<VecType>(coalesce(tCrCol_flt));
Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int<V>{}))); Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int<V>{})));
auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tCcCol_vec(coords...), residue_tCcCol); }; Tensor tCpCol_vec = cute::lazy::transform(tCcCol_vec, [&](auto const& c){ return elem_less(c, residue_tCcCol); });
copy_if(pred_fn, tCgCol_vec, tCrCol_vec); copy_if(tCpCol_vec, tCgCol_vec, tCrCol_vec);
} }
else { else {
auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tCcCol_flt(coords...), residue_tCcCol); }; Tensor tCpCol_flt = cute::lazy::transform(tCcCol_flt, [&](auto const& c){ return elem_less(c, residue_tCcCol); });
copy_if(pred_fn, tCgCol_flt, tCrCol_flt); copy_if(tCpCol_flt, tCgCol_flt, tCrCol_flt);
} }
constexpr int FrgSize = size(tCrCol_flt); constexpr int FrgSize = size(tCrCol_flt);

View File

@ -412,17 +412,13 @@ struct Sm90AuxStore<
constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){};
constexpr int V = cute::min(Alignment, size(MCL)); constexpr int V = cute::min(Alignment, size(MCL));
Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n);
Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int<V>{})));
Tensor tC_gAux_vec = recast<Array<Element, V>>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); Tensor tC_gAux_vec = recast<Array<Element, V>>(coalesce(tC_gAux(_,_,_,epi_m,epi_n)));
Tensor tC_rAux_vec = recast<Array<Element, V>>(coalesce(tC_rAux)); Tensor tC_rAux_vec = recast<Array<Element, V>>(coalesce(tC_rAux));
auto pred_fn = [&] (auto const&... coords) { Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int<V>{})));
return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); });
};
copy_if(pred_fn, tC_rAux_vec, tC_gAux_vec); copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec);
} }
}; };

View File

@ -540,6 +540,23 @@ struct HardSwish<Array<T, N> > {
} }
}; };
template <int N>
struct HardSwish<Array<half_t, N> > {
using T = half_t;
static const bool kIsHeavy = false;
static constexpr float kOneSixth = 0.16666667f;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const {
minimum<Array<T, N> > mn;
maximum<Array<T, N> > mx;
multiplies<Array<T, N> > mul;
plus<Array<T, N> > add;
return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(kOneSixth));
}
};
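// Editorial note: this half_t specialization evaluates x * min(max(x + 3, 0), 6) / 6 element-wise,
// matching the generic HardSwish above; e.g. 1.0 -> 1.0 * 4.0 / 6.0 ~= 0.667 and any x <= -3 -> 0.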
template <typename T> template <typename T>
using ScaledHardSwish = Scale<HardSwish<T>>; using ScaledHardSwish = Scale<HardSwish<T>>;

View File

@ -542,12 +542,8 @@ struct VisitorColBroadcast {
} }
} }
clear(tC_rCol); clear(tC_rCol);
Tensor pred = make_tensor<bool>(shape(tC_gCol)); Tensor tC_pCol = cute::lazy::transform(tC_cCol, [&] (auto const& c) { return get<0>(c) < m; });
CUTLASS_PRAGMA_UNROLL copy_if(tC_pCol, tC_gCol, tC_rCol);
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tC_cCol(i)) < m;
}
copy_if(pred, tC_gCol, tC_rCol);
} }
template <class ElementAccumulator, int FragmentSize> template <class ElementAccumulator, int FragmentSize>

View File

@ -446,7 +446,7 @@ public:
Status Status
construct_graph(bool launch_with_pdl) { construct_graph(bool launch_with_pdl) {
#if ((__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
Status status = Status::kSuccess; Status status = Status::kSuccess;
// Destroy existing graph, if created // Destroy existing graph, if created

View File

@ -47,7 +47,7 @@ void launch_full_barrier(
cudaStream_t stream, cudaStream_t stream,
bool launch_with_pdl) { bool launch_with_pdl) {
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
// Legacy (kernel) launch with PDL // Legacy (kernel) launch with PDL
cudaLaunchAttribute attributes[1]; cudaLaunchAttribute attributes[1];
attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;

View File

@ -268,7 +268,7 @@ struct CollectiveBuilder<
// Calculate SMEM matrix A and B buffers' pipeline stages and the accumulator stages. // Calculate SMEM matrix A and B buffers' pipeline stages and the accumulator stages.
static constexpr uint32_t AccumulatorNPerCta = cute::size<1>(TileShape_MNK{}); static constexpr uint32_t AccumulatorNPerCta = cute::size<1>(TileShape_MNK{});
static constexpr uint32_t AccumulatorPipelineStageCount = (AccumulatorNPerCta == 256) ? 1 : 2; static constexpr uint32_t AccumulatorPipelineStageCount = (AccumulatorNPerCta == 256) ? 1 : 2;
static constexpr uint32_t SchedulerPipelineStageCount = 1; static constexpr uint32_t SchedulerPipelineStageCount = 2;
using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>; using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>;

View File

@ -238,7 +238,7 @@ struct CollectiveBuilder<
static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>; static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>;
// Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler. // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler.
static constexpr bool IsGroupGemm = !cute::is_same_v<StrideA, InternalStrideA>; static constexpr bool IsGroupGemm = !cute::is_same_v<StrideA, InternalStrideA>;
static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 1); static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 2);
static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout<
ClusterShape_MNK, ClusterShape_MNK,

View File

@ -51,6 +51,8 @@ struct Sm100DenseGemmTmaUmmaCarveout {
static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage);
// CLC (scheduler) response // CLC (scheduler) response
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
// CLC Throttle pipeline storage
static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
// Tmem dealloc // Tmem dealloc
static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
// Tmem ptr storage // Tmem ptr storage
@ -64,6 +66,7 @@ struct Sm100DenseGemmTmaUmmaCarveout {
CLCPipelineStorage + CLCPipelineStorage +
LoadOrderBarrierStorage + LoadOrderBarrierStorage +
TmemDeallocStorage + TmemDeallocStorage +
CLCThrottlePipelineStorage +
CLCResponseStorage + CLCResponseStorage +
TmemBasePtrsStorage + TmemBasePtrsStorage +
TensorMapStorage TensorMapStorage
@ -80,6 +83,8 @@ struct Sm100SparseGemmTmaUmmaCarveout {
static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage); static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
// AccumulatorPipeline = PipelineUmmaAsync // AccumulatorPipeline = PipelineUmmaAsync
static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage); static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage);
// CLC Throttle pipeline storage
static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
// Tmem dealloc // Tmem dealloc
static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
@ -87,6 +92,7 @@ struct Sm100SparseGemmTmaUmmaCarveout {
cutlass::round_up(LoadOrderBarrierStorage, 16) + cutlass::round_up(LoadOrderBarrierStorage, 16) +
cutlass::round_up(CLCPipelineStorage, 16) + cutlass::round_up(CLCPipelineStorage, 16) +
cutlass::round_up(AccumulatorPipelineStorage, 16) + cutlass::round_up(AccumulatorPipelineStorage, 16) +
cutlass::round_up(CLCThrottlePipelineStorage, 16) +
cutlass::round_up(TmemDeallocStorage, 16), cutlass::round_up(TmemDeallocStorage, 16),
16)); 16));

View File

@@ -371,7 +371,7 @@ struct CollectiveBuilder<
// Calculate SMEM matrix A and B buffers' pipeline stages and the accumulator stages.
static constexpr uint32_t AccumulatorNPerCta = cute::size<1>(TileShape_MNK{});
static constexpr uint32_t AccumulatorPipelineStageCount = AccumulatorNPerCta > 224 ? 1 : 2;
-static constexpr uint32_t SchedulerPipelineStageCount = 1;
+static constexpr uint32_t SchedulerPipelineStageCount = 2;
using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>;
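This hunk and the grouped-GEMM builder hunk below raise the default SchedulerPipelineStageCount from 1 to 2, which is also why the carveout structs earlier in this diff gain a CLCThrottlePipelineStorage term: both the CLC response buffers and the new throttle pipeline scale with the stage count. A small sketch of that scaling (the per-stage byte sizes are assumed for illustration, not taken from the CUTLASS headers):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t kCLCResponseBytes   = 16;  // assumed size of one CLC response slot
  constexpr uint32_t kThrottleStageBytes = 16;  // assumed size of one PipelineAsync stage
  const uint32_t stage_counts[] = {1u, 2u};
  for (uint32_t stages : stage_counts) {
    std::printf("stages=%u -> CLC responses %u B, throttle pipeline %u B\n",
                stages, stages * kCLCResponseBytes, stages * kThrottleStageBytes);
  }
  return 0;
}

With two stages the tile scheduler can, in principle, fetch the next work assignment while the current one is being consumed, at the cost of the extra shared-memory carveout shown above.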
View File
@@ -267,7 +267,7 @@ struct CollectiveBuilder<
static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v<KernelScheduleSm100PtrArrayDenseGemm, BuilderScheduleTag>);
// Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler.
static constexpr bool IsGroupGemm = !cute::is_same_v<StrideA, InternalStrideA>;
-static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 1);
+static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 2);
static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout<
ClusterShape_MNK,

View File
@@ -0,0 +1,264 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/collective/builders/sm120_common.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <
int CapacityBytes,
class ElementA,
class ElementB,
class ElementScalar,
class TileShapeMNK,
class ScaleShapeMNK,
class MainloopPipelineStorage,
int stages
>
constexpr int
sm120_compute_stage_count_or_override_blockwise(StageCount<stages> stage_count) {
return stages;
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity.
template <
int CapacityBytes,
class ElementA,
class ElementB,
class ElementScalar,
class TileShapeMNK,
class ScaleShapeMNK,
class MainloopPipelineStorage,
int carveout_bytes
>
constexpr auto
sm120_compute_stage_count_or_override_blockwise(StageCountAutoCarveout<carveout_bytes> stage_count) {
// For F6/F4 sub-bytes, ElementA/B will be passed in as uint8_t
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
constexpr auto scale_bits = cute::sizeof_bits_v<ElementScalar>;
constexpr auto mainloop_pipeline_bytes = sizeof(MainloopPipelineStorage);
constexpr int stage_bytes =
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(scale_bits * size<0>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
cutlass::bits_to_bytes(scale_bits * size<1>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
static_cast<int>(mainloop_pipeline_bytes);
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
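// Worked example (illustrative numbers, not asserted anywhere in this file): with
// TileShapeMNK = (128,128,128), 8-bit A/B, 32-bit scales, and one scale per operand per
// tile (ScaleShapeMNK = (1,1,1)), a single stage costs
//   A: 128*128*8 bits = 16384 B,  B: 128*128*8 bits = 16384 B,  SFA+SFB: 8 B,
// plus sizeof(MainloopPipelineStorage), roughly 32.8 KB in total, so the function returns
// about (CapacityBytes - carveout_bytes) / 32.8 KB stages.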
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ElementA,
class GmemLayoutATagPair,
int AlignmentA,
class ElementB,
class GmemLayoutBTagPair,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class BuilderScheduleTag
>
struct CollectiveBuilder<
arch::Sm120,
arch::OpClassTensorOp,
ElementA,
GmemLayoutATagPair,
AlignmentA,
ElementB,
GmemLayoutBTagPair,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
BuilderScheduleTag,
cute::enable_if_t<
not cute::is_tuple_v<ElementA> && not cute::is_tuple_v<ElementB> &&
not cute::is_complex_v<ElementA> && not cute::is_complex_v<ElementB> &&
cute::is_tuple_v<GmemLayoutATagPair> && cute::is_tuple_v<GmemLayoutBTagPair> &&
(cute::is_base_of_v<KernelScheduleSm120Blockwise, BuilderScheduleTag> ||
cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>) &&
detail::sm1xx_gemm_is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, BuilderScheduleTag>()>>
{
static_assert(detail::is_sm10x_f8f6f4_element<ElementA>() && detail::is_sm10x_f8f6f4_element<ElementB>(),
"SM120 TmaWarpSpecialized blockwise scaling builder currently only supports F8F6F4 MMA.");
static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
static_assert(cute::is_static_v<ClusterShape_MNK>, "Cluster has to be static");
using GmemLayoutATag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutATagPair{}))>;
using GmemLayoutSFATag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutATagPair{}))>;
using GmemLayoutBTag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutBTagPair{}))>;
using GmemLayoutSFBTag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutBTagPair{}))>;
static_assert(cute::depth(cute::remove_pointer_t<GmemLayoutSFATag>{}) == 2 and
cute::depth(cute::remove_pointer_t<GmemLayoutSFBTag>{}) == 2,
"Expect SFA and SFB layout to be depth of two with shape ((SFVecMN, restMN),(SFVecK, restK), L)");
static_assert(size<1, 0>(cute::remove_pointer_t<GmemLayoutSFATag>{}) ==
size<1, 0>(cute::remove_pointer_t<GmemLayoutSFBTag>{}),
"SFA and SFB must have equivalent SF vector sizes along K");
static constexpr cute::UMMA::Major UmmaMajorA = detail::tag_to_umma_major_A<GmemLayoutATag>();
static constexpr cute::UMMA::Major UmmaMajorB = detail::tag_to_umma_major_B<GmemLayoutBTag>();
static_assert((UmmaMajorA == UMMA::Major::K && UmmaMajorB == UMMA::Major::K), "Only TN layout is supported.");
using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{}));
using PermTileN = decltype(cute::min(size<1>(TileShape_MNK{}), _32{}));
static constexpr bool IsCooperative = !cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_4,_2,_1>>, Layout<Shape<_2,_2,_1>>>;
// Data type used by MMA instruction
using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementA>());
using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementB>());
static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement<ElementAMma, ElementBMma,
TileShape_MNK, ClusterShape_MNK,
GmemLayoutATag, GmemLayoutBTag, false /*IsSparse*/>(),
"TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" );
// Setup TiledMma
using TiledMma = decltype(cute::make_tiled_mma(
cute::rr_op_selector_sm120<ElementA, ElementB, ElementAccumulator>(),
AtomLayoutMNK{},
Tile<PermTileM, PermTileN, _32>{}
));
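// Illustrative example: with TileShape_MNK = (128,128,128) and the cooperative schedule,
// PermTileM = 128 and PermTileN = 32, so the tiled MMA above is a 4x2x1 arrangement of
// atoms permuted over a 128x32x32 tile.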
// DType check
static constexpr bool UseF8f6f4 = detail::is_sm120_f8f6f4<TiledMma, ElementA, ElementB>();
static_assert(UseF8f6f4, "Non-blockscaled collective builder only supports F8F6F4 MMA.\n");
// Element type
using SmemAllocTypeA = cute::conditional_t<UseF8f6f4, uint8_t, typename TiledMma::ValTypeA>;
using SmemAllocTypeB = cute::conditional_t<UseF8f6f4, uint8_t, typename TiledMma::ValTypeB>;
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::sm120_rr_smem_selector<SmemAllocTypeA, decltype(size<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::sm120_rr_smem_selector<SmemAllocTypeB, decltype(size<2>(TileShape_MNK{}))>());
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
using StrideB = cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>;
using StrideSFA = cutlass::gemm::TagToStrideA_t<GmemLayoutSFATag>;
using StrideSFB = cutlass::gemm::TagToStrideB_t<GmemLayoutSFBTag>;
static constexpr int ScaleGranularityM = size<0,0>(cute::remove_pointer_t<GmemLayoutSFATag>{});
static constexpr int ScaleGranularityN = size<0,0>(cute::remove_pointer_t<GmemLayoutSFBTag>{});
static constexpr int ScaleGranularityK = size<1,0>(cute::remove_pointer_t<GmemLayoutSFBTag>{});
static_assert(size<0>(TileShape_MNK{}) % ScaleGranularityM == 0, "Scale Granularity M must evenly divide the tile shape M.");
static_assert(size<1>(TileShape_MNK{}) % ScaleGranularityN == 0, "Scale Granularity N must evenly divide the tile shape N.");
static_assert(size<2>(TileShape_MNK{}) == ScaleGranularityK , "Scale Granularity K must be equal to the tile shape K.");
using BlockTileScale_M = Int<size<0>(TileShape_MNK{}) / ScaleGranularityM>;
using BlockTileScale_N = Int<size<1>(TileShape_MNK{}) / ScaleGranularityN>;
using BlockTileScale_K = Int<size<2>(TileShape_MNK{}) / ScaleGranularityK>;
using ScaleTileShape = cute::Shape<BlockTileScale_M, BlockTileScale_N, BlockTileScale_K>;
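// Example (illustrative): a 128x128x128 tile with scale granularity (M,N,K) = (128,128,128)
// gives ScaleTileShape = (_1,_1,_1), i.e. one SFA and one SFB value per tile, whereas a
// granularity of (1,128,128) would give 128 SFA values along M instead.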
// Setup Stages and DispatchPolicy
using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage;
static constexpr int PipelineStages = detail::sm120_compute_stage_count_or_override_blockwise<
detail::sm120_smem_capacity_bytes, SmemAllocTypeA,
SmemAllocTypeB, ElementAccumulator,
TileShape_MNK, ScaleTileShape, MainloopPipelineStorage>(StageCountType{});
static constexpr uint32_t SchedulerPipelineStageCount = 2;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<cute::remove_pointer_t<StrideA>, StrideA>;
using KernelSchedule = cute::conditional_t<IsGroupedGemmKernel,
// PtrArray
cute::conditional_t<IsCooperative,
KernelPtrArrayTmaWarpSpecializedCooperativeBlockwiseScalingSm120<SchedulerPipelineStageCount>,
KernelPtrArrayTmaWarpSpecializedPingpongBlockwiseScalingSm120<SchedulerPipelineStageCount>>,
// Non-PtrArray
cute::conditional_t<IsCooperative,
KernelTmaWarpSpecializedCooperativeBlockwiseScalingSm120<SchedulerPipelineStageCount>,
KernelTmaWarpSpecializedPingpongBlockwiseScalingSm120<SchedulerPipelineStageCount>>>;
using DispatchPolicy = cute::conditional_t<IsGroupedGemmKernel,
MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling<PipelineStages,
SchedulerPipelineStageCount,
ClusterShape_MNK,
KernelSchedule>,
MainloopSm120TmaWarpSpecializedBlockwiseScaling<PipelineStages,
SchedulerPipelineStageCount,
ClusterShape_MNK,
KernelSchedule>>;
using SmemCopyAtomA = Copy_Atom<decltype(detail::sm120_rr_smem_copy_selector_A<ElementA, ElementB, UseF8f6f4>()), SmemAllocTypeA>;
using SmemCopyAtomB = Copy_Atom<decltype(detail::sm120_rr_smem_copy_selector_B<ElementA, ElementB, UseF8f6f4>()), SmemAllocTypeB>;
using CollectiveOp = CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
cute::tuple<StrideA, StrideSFA>,
ElementB,
cute::tuple<StrideB, StrideSFB>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity
>;
};
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
View File
@@ -66,6 +66,8 @@ struct CollectiveBuilder<
StageCountType,
BuilderScheduleTag,
cute::enable_if_t<
+not cute::is_tuple_v<ElementA> && not cute::is_tuple_v<ElementB> &&
+not cute::is_tuple_v<GmemLayoutATag> && not cute::is_tuple_v<GmemLayoutBTag> &&
// Dense Gemm
(cute::is_base_of_v<KernelScheduleSm120DenseGemm, BuilderScheduleTag> ||
cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag> ||

View File
@@ -50,6 +50,7 @@
#include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl"
#include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl"
#include "cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl"
+#include "cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl"
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////

View File
@@ -67,6 +67,8 @@
#include "cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp"
#include "cutlass/gemm/collective/sm120_sparse_mma_tma.hpp"
#include "cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp"
+#include "cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp"
+#include "cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp"
#endif // !defined(__CUDACC_RTC__)

View File
@@ -28,10 +28,6 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
@@ -51,7 +47,6 @@
#include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
-#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -169,7 +164,6 @@ struct CollectiveMma<
using InternalStrideB = cute::remove_pointer_t<StrideB>;
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
@@ -210,19 +204,15 @@ struct CollectiveMma<
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)");
-static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
-"SmemLayoutAtom must evenly divide tile shape.");
-static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
-"SmemLayoutAtom must evenly divide tile shape.");
+static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape.");
+static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)");
-static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
-"SmemLayoutAtom must evenly divide tile shape.");
-static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
-"SmemLayoutAtom must evenly divide tile shape.");
+static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape.");
+static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
@@ -275,8 +265,8 @@ struct CollectiveMma<
using SmemAllocTypeA = cute::conditional_t<IsF8F6F4 && cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
using SmemAllocTypeB = cute::conditional_t<IsF8F6F4 && cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
-using BitTypeElementA = uint_bit_t<cute::sizeof_bits_v<ElementA>>;
-using BitTypeElementB = uint_bit_t<cute::sizeof_bits_v<ElementB>>;
+using BitTypeElementA = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
+using BitTypeElementB = cute::uint_bit_t<cute::sizeof_bits_v<ElementB>>;
using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA, BitTypeElementA, ElementA>;
using ArrayElementB = cute::conditional_t<IsRuntimeDataTypeB, BitTypeElementB, ElementB>;
@@ -308,15 +298,22 @@ struct CollectiveMma<
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
static constexpr uint32_t SFTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v<ElementSF>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v<ElementSF>);
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
static constexpr uint32_t ABTmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes;
+template <class AccTensor, class SfaTensor, class SfbTensor>
+struct TmemStorage {
+AccTensor accumulators;
+SfaTensor tCtSFA;
+SfbTensor tCtSFB;
+};
// Host side kernel arguments
struct Arguments {
ArrayElementA const** ptr_A{nullptr};
@@ -401,7 +398,11 @@ struct CollectiveMma<
CUTLASS_DEVICE
CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster)
: cluster_shape_(cluster_shape)
-, block_rank_in_cluster_(block_rank_in_cluster) {
+, block_rank_in_cluster_(block_rank_in_cluster)
+, layout_SFA_(params.layout_SFA)
+, layout_SFB_(params.layout_SFB)
+, runtime_data_type_a_(params.runtime_data_type_a)
+, runtime_data_type_b_(params.runtime_data_type_b) {
if constexpr (IsDynamicCluster) {
const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
@@ -613,18 +614,48 @@ struct CollectiveMma<
}
/// Construct A Single Stage's Accumulator Shape
-CUTLASS_DEVICE auto
+CUTLASS_DEVICE static
+auto
partition_accumulator_shape() {
auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)
return acc_shape;
}
-template <class FrgEngine, class FrgLayout>
-CUTLASS_DEVICE auto
-slice_accumulator(cute::Tensor<FrgEngine, FrgLayout> const& accumulators, int stage) {
-return accumulators(_,_,_,stage);
-}
+template <class TmemStorage>
+CUTLASS_DEVICE static
+auto
+slice_accumulator(TmemStorage tmem_storage, int stage) {
+return tmem_storage.accumulators(_,_,_,stage);
+}
+template <class EpilogueTile, bool IsOverlappingAccum = false>
+CUTLASS_DEVICE static
+auto
+init_tmem_tensors(EpilogueTile epi_tile) {
+TiledMma tiled_mma;
+auto acc_shape = partition_accumulator_shape();
+// ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue.
+Tensor accumulators = cutlass::detail::make_sm100_accumulator<AccumulatorPipelineStageCount, IsOverlappingAccum>(
+tiled_mma, acc_shape, EpilogueTile{});
+Tensor tCtSFA = make_tensor<typename TiledMma::FrgTypeSFA>(shape(SmemLayoutAtomSFA{}));
+Tensor tCtSFB = make_tensor<typename TiledMma::FrgTypeSFB>(shape(SmemLayoutAtomSFB{}));
+TmemStorage<decltype(accumulators), decltype(tCtSFA), decltype(tCtSFB)> tmem_storage;
+tmem_storage.accumulators = accumulators;
+tmem_storage.tCtSFA = tCtSFA;
+tmem_storage.tCtSFB = tCtSFB;
+return tmem_storage;
+}
+template <class TmemStorage>
+CUTLASS_DEVICE static
+void
+set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
+tmem_storage.accumulators.data() = tmem_base_addr;
+tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators);
+tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA);
+}
/// Set up the data needed by this collective for load.
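The new init_tmem_tensors / set_tmem_offsets pair in this hunk centralizes tensor-memory bookkeeping: each TMEM tensor starts at the column offset where the previous one ends. A minimal host-side sketch of the same chaining idea (names and types are illustrative, not the CUTLASS API):

#include <cstdint>

struct TmemOffsets {
  uint32_t acc;  // accumulator base column
  uint32_t sfa;  // SFA placed right after the accumulator columns
  uint32_t sfb;  // SFB placed right after SFA
};

// Mirrors the offset chaining performed by set_tmem_offsets above.
constexpr TmemOffsets chain_offsets(uint32_t base, uint32_t acc_cols, uint32_t sfa_cols) {
  return TmemOffsets{base, base + acc_cols, base + acc_cols + sfa_cols};
}
static_assert(chain_offsets(0, 128, 4).sfb == 132, "SFB starts after the accumulator and SFA columns");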
@@ -693,9 +724,9 @@ struct CollectiveMma<
}
else if constexpr (IsCtaN64) {
Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB));
auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp),
make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp));
auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp),
make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp));
return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride));
}
@@ -707,7 +738,6 @@ struct CollectiveMma<
Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l)
Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l)
-// Partition for this CTA
ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{}));
@@ -770,17 +800,15 @@ struct CollectiveMma<
}
/// Set up the data needed by this collective for mma compute.
-template <class FrgEngine, class FrgLayout>
+template <class TmemStorage>
CUTLASS_DEVICE auto
mma_init(
-Params const& params,
-[[maybe_unused]] cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
-TensorStorage& shared_tensors,
-uint32_t const tmem_offset) const {
+TmemStorage tmem_storage,
+TensorStorage& shared_tensors) const {
// Allocate "fragments/descriptors" for A and B matrices
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
// Allocate "fragments/descriptors" for A and B matrices
Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
@@ -792,13 +820,8 @@ struct CollectiveMma<
//
// Scale Factor
//
-Tensor tCtSFA = make_tensor<typename TiledMma::FrgTypeSFA>(shape(SmemLayoutAtomSFA{}));
-// Set tCtSFA and tCtSFB start addresses. Only update the TMEM column address by masking the address with 0x000001FF.
-// TMEM allocations for SFA and SFB will always start at DP 0.
-tCtSFA.data() = tmem_offset;
-Tensor tCtSFB = make_tensor<typename TiledMma::FrgTypeSFB>(shape(SmemLayoutAtomSFB{}));
-tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA);
+Tensor tCtSFA = tmem_storage.tCtSFA;
+Tensor tCtSFB = tmem_storage.tCtSFB;
// Setup smem descriptors for UTCCP
Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{});
Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});
@@ -831,8 +854,10 @@ struct CollectiveMma<
TiledMma tiled_mma;
if constexpr (IsRuntimeDataType) {
-tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111;
-tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111;
+// Update instruction descriptor according to runtime argument.
+// Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe.
+tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111;
+tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111;
}
return cute::make_tuple(
@@ -997,45 +1022,52 @@ struct CollectiveMma<
//
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
-if (k_tile_count > 0) { // first iteraion
+if constexpr (IsOverlappingAccum) {
+// first iteration manual unroll for tmem overlap kernel
+if (k_tile_count > 0) {
// WAIT on mainloop_pipe_consumer_state until its data are available
// (phase bit flips from mainloop_pipe_consumer_state.phase() value)
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
// Compute on k_tile
int read_stage = mainloop_pipe_consumer_state.index();
// Save current mainlop pipeline read state
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
// Advance mainloop_pipe
++mainloop_pipe_consumer_state;
--k_tile_count;
skip_wait = k_tile_count <= 0;
// Peek at next iteration
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
if (cute::elect_one_sync()) {
copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t);
copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
}
-if constexpr (IsOverlappingAccum) {
-accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
-}
+// Wait for tmem accumulator buffer to become empty with a flipped phase
+accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
// Unroll the K mode manually so we can set scale C to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma.with(tiled_mma.accumulate_,
tCtSFA(_,_,k_block),
tCtSFB_mma(_,_,k_block)),
tCrA(_,_,k_block,read_stage),
tCrB(_,_,k_block,read_stage),
accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
+}
+else {
+// Wait for tmem accumulator buffer to become empty with a flipped phase
+accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
+}
CUTLASS_PRAGMA_NO_UNROLL
@@ -1073,6 +1105,7 @@ struct CollectiveMma<
accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
@@ -1273,6 +1306,11 @@ protected:
typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr};
typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr};
+LayoutSFA layout_SFA_;
+LayoutSFB layout_SFB_;
+RuntimeDataTypeA runtime_data_type_a_{};
+RuntimeDataTypeB runtime_data_type_b_{};
ClusterShape cluster_shape_;
uint32_t block_rank_in_cluster_;
};

View File
@@ -47,7 +47,6 @@
#include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
-#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -123,7 +122,7 @@ struct CollectiveMma<
"Static cluster shape used: TileShape should be evenly divided by TiledMma");
using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{}));
static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or
shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256,
"Cta N should be one of 64/128/192/256");
@@ -726,9 +725,9 @@ struct CollectiveMma<
}
else if constexpr (IsCtaN64) {
Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_));
auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp),
make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp));
auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp),
make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp));
return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride));
}