diff --git a/CHANGELOG.md b/CHANGELOG.md index fc269c8b..9f423fb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,31 @@ # NVIDIA CUTLASS Changelog + +## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-03-20) + +* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API: + - Collective mainloops that target: + * [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp) + * [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp) + - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp). +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture: + - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). + - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). + - [Blockscaled GEMM with mixed input datatypes (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). +* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. +* Enhancements and new support for block-wise and group-wise GEMM for Hopper and Blackwell architectures: + - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. + - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. + - Support for [grouped GEMM with blockwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. + - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. +* Added support for enhanced kernel performance search in CUTLASS: + - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. + - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance. + - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
+ - A more detailed introduction and examples of how to leverage this feature can be found in [profiler.md](./media/docs/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + ## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25) * Support for new CuTe building blocks specifically for Blackwell SM100 architecture: @@ -538,4 +564,3 @@ SPDX-License-Identifier: BSD-3-Clause OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ``` - diff --git a/CMakeLists.txt b/CMakeLists.txt index 65821237..1e6f298e 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,6 +102,8 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON) list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr) +list(APPEND CUTLASS_CUDA_NVCC_FLAGS -ftemplate-backtrace-limit=0) + if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE) endif() @@ -173,7 +175,7 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 101 101a 120 120a) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") @@ -441,7 +443,7 @@ if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR}) - list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v) # --keep-dir may not work with nvcc for some directories. + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v -objtemp) # --keep-dir may not work with nvcc for some directories. list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR}) endif() @@ -468,6 +470,13 @@ if(UNIX) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing) endif() +# Known ctk11.4 issue (fixed later) +# Also see https://stackoverflow.com/questions/64523302/cuda-missing-return-statement-at-end-of-non-void-function-in-constexpr-if-fun +if (CUDA_VERSION VERSION_LESS 11.5.0) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcudafe "--diag_suppress=implicit_return_from_non_void_function" ) + message("CUDA_VERSION check pass ${CUDA_VERSION}") +endif() + # Don't leak lineinfo in release builds if (NOT CMAKE_BUILD_TYPE MATCHES "Release") list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt) @@ -1045,6 +1054,7 @@ function(cutlass_generate_profiler_tests NAME) string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}") string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}") string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "swizzle_size=" "" TEST_NAME "${TEST_NAME}") string(REGEX REPLACE "verification_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}") string(REGEX REPLACE "warmup_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}") string(REGEX REPLACE "profiling_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}") diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 843ed365..46506007 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -128,3 +128,35 @@ Bryce Lelbach
Joel McCormack
Kyrylo Perelygin
Sean Treichler
+ +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index c91fc06a..176b42e4 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -2,10 +2,14 @@ ## 2025 +- ["Comet: Fine-grained Computation-communication Overlapping for Mixture-of-Experts"](https://arxiv.org/abs/2502.19811). Shulai Zhang, Ningxin Zheng, Haibin Lin, Ziheng Jiang, Wenlei Bao, Chengquan Jiang, Qi Hou, Weihao Cui, Size Zheng, Li-Wen Chang, Quan Chen, Xin Liu. _arXiv_, February 2025. + - ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025. ## 2024 +- ["DeepSeek-V3 Technical Report"](https://arxiv.org/abs/2412.19437). DeepSeek-AI. _arXiv_, December 2024. + - ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024. - ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024. @@ -64,3 +68,35 @@ "](https://arxiv.org/abs/2008.13006). Cong Guo, Bo Yang Hsueh, Jingwen Leng, Yuxian Qiu, Yue Guan, Zehuan Wang, Xiaoying Jia, Xipeng Li, Minyi Guo, Yuhao Zhu. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020. - ["Strassen's Algorithm Reloaded on GPUs"](https://dl.acm.org/doi/10.1145/3372419). Jianyu Huang, Chenhan D. Yu, Robert A. van de Geijn. 
_ACM Transactions on Mathematical Software_, March 2020. + +## Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/README.md b/README.md index ada18b39..77a81620 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.8.0 +# CUTLASS 3.9.0 -_CUTLASS 3.8.0 - January 2025_ +_CUTLASS 3.9.0 - March 2025_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -38,65 +38,30 @@ See the [functionality docs](./media/docs/functionality.md) for a more comprehen list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU architecture. -# What's New in CUTLASS 3.8 +# What's New in CUTLASS 3.9 -CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture. -For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8. - -* Support for new CuTe building blocks specifically for Blackwell SM100 architecture: - - [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms. - - Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms. - - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale. - - Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe. - - [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms. - - Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms. 
-* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture: - - Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h) - - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp). - - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp). - - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. - - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). - - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. -* Full support for Blackwell SM100 kernels in CUTLASS 3.x API: - - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that - + Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture. - + Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators. - + Support stream-K load balancing for all kernel types everywhere via composable scheduler support. - - Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for - * [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp) - * [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp) - * [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp) - * [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp) - - Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad. - - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. - - [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions](). -* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification. - - Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes. - - Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors. - - Support for mixed input GEMM kernels on Hopper in the profiler. -* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels. 
-* A new 3.x version of grouped GEMM to the CUTLASS library and generates kernels for Hopper and Blackwell. Now grouped GEMM support is enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details). -* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture: - - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API. - - GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell. - - Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores: - + [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu) - + [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu) - + [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu) - - GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy. - - [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu). - - Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu). - - Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu). - - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128. - - A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations. -* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture: - - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes. - - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu). -* Documentation updates: - - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). - - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md) - - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. - - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). 
+* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API: + - Collective mainloops that target: + * [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp) + * [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp) + - New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp). +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture: + - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). + - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). + - [Blockscaled GEMM with mixed input datatypes (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). +* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. +* Enhancements and new support for block-wise and group-wise GEMM for Hopper and Blackwell architectures: + - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. + - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. + - Support for [grouped GEMM with blockwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. + - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. +* Added support for enhanced kernel performance search in CUTLASS: + - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. + - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance. + - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. + - A more detailed introduction and examples of how to leverage this feature can be found in [profiler.md](./media/docs/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix.
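The blockwise and groupwise scaling items listed above, and the `ceil_div`-based shape computations introduced in the example 67 hunks later in this diff, rest on the same arithmetic: scale-factor tensor extents are the problem extents divided (rounding up) by either the CTA tile shape or the scaling granularity. The short standalone sketch below mirrors that arithmetic; it is not CUTLASS API, and the problem sizes, tile sizes, and granularities are illustrative assumptions only.

```cpp
// Minimal sketch (not CUTLASS code): derive scale-tensor extents from a GEMM
// problem shape, mirroring the ceil_div usage in the updated example 67.
// All concrete sizes below are illustrative assumptions.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  int m = 2816, n = 3072, k = 16384;             // assumed GEMM problem shape
  int tile_m = 128, tile_n = 128, tile_k = 128;  // assumed CTA tile shape
  int scale_granularity_m = 64;                  // assumed group-scaling granularity along M
  int scale_granularity_n = 128;                 // assumed group-scaling granularity along N

  // Blockwise scaling: one scale factor per CTA tile along each mode.
  int blockscale_m = ceil_div(m, tile_m);
  int blockscale_n = ceil_div(n, tile_n);
  int blockscale_k = ceil_div(k, tile_k);

  // Groupwise scaling: finer granularity along M/N; K still follows the tile.
  int groupscale_m = ceil_div(m, scale_granularity_m);
  int groupscale_n = ceil_div(n, scale_granularity_n);

  std::printf("blockscale extents: %d x %d x %d\n", blockscale_m, blockscale_n, blockscale_k);
  std::printf("groupscale extents: %d x %d (K blocks: %d)\n", groupscale_m, groupscale_n, blockscale_k);
  return 0;
}
```

In the actual example code changed below, the divisors come from `TileShape` and the collective's `ScaleGranularityM`/`ScaleGranularityN`, as seen in the updated `initialize()` and `verify()` functions.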
diff --git a/customConfigs.cmake b/customConfigs.cmake index e39212db..d98fe6c5 100644 --- a/customConfigs.cmake +++ b/customConfigs.cmake @@ -65,10 +65,10 @@ endfunction() if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS) - set(PROFILER_ARCH_LIST 100a) + set(PROFILER_ARCH_LIST 100a 101a 120a) foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS) if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST)) - message(FATAL_ERROR "Only SM100a compute capability is supported with profiler-based unit tests") + message(FATAL_ERROR "Only SM100a/101a/120a compute capability is supported with profiler-based unit tests") endif() endforeach() diff --git a/examples/13_two_tensor_op_fusion/README.md b/examples/13_two_tensor_op_fusion/README.md index 9fa8297d..ed9b2727 100644 --- a/examples/13_two_tensor_op_fusion/README.md +++ b/examples/13_two_tensor_op_fusion/README.md @@ -115,4 +115,3 @@ SPDX-License-Identifier: BSD-3-Clause OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ``` - diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md index c670e340..02222f8e 100644 --- a/examples/40_cutlass_py/README.md +++ b/examples/40_cutlass_py/README.md @@ -2,3 +2,35 @@ This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface. For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory. + +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+``` diff --git a/examples/40_cutlass_py/customizable/README.md b/examples/40_cutlass_py/customizable/README.md index e8aeee9e..b6863fb0 100644 --- a/examples/40_cutlass_py/customizable/README.md +++ b/examples/40_cutlass_py/customizable/README.md @@ -165,3 +165,35 @@ Example 7: GELU ```python python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu ``` + +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index ca64c901..7c61e75c 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -41,3 +41,35 @@ We are currently optimizing the following cases: * Optimizations for memory bound cases. * Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size. + +## Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. 
Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/59_ampere_gather_scatter_conv/README.md b/examples/59_ampere_gather_scatter_conv/README.md index 4aac0536..2f3d8b83 100644 --- a/examples/59_ampere_gather_scatter_conv/README.md +++ b/examples/59_ampere_gather_scatter_conv/README.md @@ -207,3 +207,35 @@ With this in mind, this example kernel has the following limitations: - This example kernel only supports dynamic image count, all other conv problem shape must be defined as `cute::Constant<>`s - Problem shapes (including dynamic image count `N`) must be evenly divisible by the tile shape - It does not perform fp32->tf32 numeric conversion, gmem inputs must be rounded to tf32 already + +## Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+``` diff --git a/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt index c9f638e6..72f59476 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt +++ b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt @@ -26,11 +26,13 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -include_directories( - . -) +set(TEST_PREFETCH_CASE --m=8192 --n=64 --k=8192 --iterations=0) cutlass_example_add_executable( 63_hopper_gemm_with_weight_prefetch 63_hopper_gemm_with_weight_prefetch.cu - ) + TEST_COMMAND_OPTIONS + TEST_PREFETCH_CASE +) + +target_include_directories(63_hopper_gemm_with_weight_prefetch PUBLIC .) diff --git a/examples/63_hopper_gemm_with_weight_prefetch/README.md b/examples/63_hopper_gemm_with_weight_prefetch/README.md index 5dac1cc6..3fd615ff 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/README.md +++ b/examples/63_hopper_gemm_with_weight_prefetch/README.md @@ -74,9 +74,40 @@ echo "Overlap ratio of 0.8, prefetch ratio of 0.7" However, note that the example still runs a single GEMM, and most of the performance improvement is expected in end to end applications. - ## Limitations * The parameter defaults are typically not good choices, especially `prefetch_ratio`. When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a memory barrier before issuing every single TMA load, and in many cases this will slow down prefetching to the point of being almost ineffective. + +## Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+``` diff --git a/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp index 0c54bc05..73655ad2 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp +++ b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp @@ -362,11 +362,11 @@ public: using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier; auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier( blockDim.x * blockDim.y * blockDim.z, - /*reserved_named_barriers_*/ 14); + /*id*/ 0); // Prefetcher warp doesn't arrive on this barrier. auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier( blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp, - /*reserved_named_barriers_*/ 15); + /*id*/ 1); if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { __syncwarp(); diff --git a/examples/65_distributed_gemm/README.md b/examples/65_distributed_gemm/README.md index fc53e6bf..e3c48a9d 100644 --- a/examples/65_distributed_gemm/README.md +++ b/examples/65_distributed_gemm/README.md @@ -62,3 +62,36 @@ procedure is the same, simply modify the following line in the example: ```cpp using TP = _8; ``` + +## Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` + diff --git a/examples/65_distributed_gemm/REQUIREMENTS.md b/examples/65_distributed_gemm/REQUIREMENTS.md index cc0d5632..4b8cca3b 100644 --- a/examples/65_distributed_gemm/REQUIREMENTS.md +++ b/examples/65_distributed_gemm/REQUIREMENTS.md @@ -84,3 +84,35 @@ GPU5 OK OK OK OK OK X OK OK GPU6 OK OK OK OK OK OK X OK GPU7 OK OK OK OK OK OK OK X ``` + +## Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu index e4afcb30..1c21678f 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) // C matrix configuration -using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using ElementC = float; // Element type for C and D matrix operands using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) @@ -251,93 +251,93 @@ struct Result ///////////////////////////////////////////////////////////////////////////////////////////////// /// Helper to initialize a block of device data - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { - if (dist_kind == cutlass::Distribution::Uniform) { + if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; - if (bits_input 
== 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::AllZeros) { - cutlass::reference::host::TensorFill(view); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - throw std::runtime_error("Not implementated."); + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; } - return true; + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, bits_input); } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} /// Helper to initialize a block of device data (scale_tensors) - template - bool initialize_scale_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { +template +bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { - if (dist_kind == cutlass::Distribution::Uniform) { + if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; + double scope_max, scope_min; - scope_min = -1; - scope_max = 1; + scope_min = -1; + scope_max = 1; - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::AllZeros) { - cutlass::reference::host::TensorFill(view); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - throw std::runtime_error("Not implementated."); - } - - return true; + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min); } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 
0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const Options &options) { @@ -438,14 +438,18 @@ void initialize(const Options &options) { if (IsDFp8 && options.save_amax) { abs_max_D.resize(cutlass::make_Coord(1)); + initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0); abs_max_D.sync_device(); reference_abs_max_D.resize(cutlass::make_Coord(1)); + initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0); } if (IsAuxFp8 && options.save_aux && options.save_amax) { abs_max_aux.resize(cutlass::make_Coord(1)); + initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0); abs_max_aux.sync_device(); reference_abs_max_aux.resize(cutlass::make_Coord(1)); + initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0); } } @@ -517,10 +521,9 @@ bool verify(const Options &options) { // Block scaling tensors shapes based CTA Block (TileShape) and GEMM Problem shape auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); - auto blockscale_m = cute::get<0>(blockscale_shape); - auto blockscale_n = cute::get<1>(blockscale_shape); - auto blockscale_k = cute::get<2>(blockscale_shape); + auto blockscale_m = ceil_div(options.m, get<0>(TileShape{})); + auto blockscale_n = ceil_div(options.n, get<1>(TileShape{})); + auto blockscale_k = ceil_div(options.k, get<2>(TileShape{})); // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(tensor_A.host_data(), @@ -608,29 +611,40 @@ bool verify(const Options &options) { cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); // compare_reference + bool passed = true; tensor_D.sync_host(); - bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor)); + double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view()); + double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view()); + double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view()); + std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl; - if (false) { - std::cout << "tensor_ref_D.host_view() {" << std::endl - << tensor_ref_D.host_view() << std::endl - << "}" << std::endl; - std::cout << "tensor_D.host_view() {" << std::endl - << tensor_D.host_view() << std::endl - << "}" << std::endl; - } +#if 0 + std::cout << "tensor_ref_D.host_view() {" << std::endl + << tensor_ref_D.host_view() << std::endl + << "}" << std::endl; + std::cout << "tensor_D.host_view() {" << std::endl + << tensor_D.host_view() << std::endl + << "}" << std::endl; +#endif if (IsDFp8 && options.save_amax) { abs_max_D.sync_host(); - passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0)); + std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << 
reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl; + passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor)); } if (options.save_aux) { tensor_aux.sync_host(); - passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view()); + passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor)); + mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view()); + mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view()); + max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view()); + std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl; if (IsAuxFp8 && options.save_amax) { abs_max_aux.sync_host(); - passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0)); + std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl; + passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor)); } } @@ -671,10 +685,9 @@ int run(Options &options) std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; } - - // if (!result.passed) { - // exit(-1); - // } + else { + result.passed = true; + } // Run profiling loop if (options.iterations > 0) @@ -707,7 +720,7 @@ int run(Options &options) std::cout << " GFLOPS: " << result.gflops << std::endl; } - return 0; + return result.passed; } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) @@ -753,7 +766,9 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - run(options); + bool passed = run(options); + if (!passed) + return -1; #endif return 0; diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index 03945764..b7cdb00a 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) // C matrix configuration -using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using ElementC = float; // Element type for C and D matrix operands using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) @@ -303,93 +303,93 @@ struct Result ///////////////////////////////////////////////////////////////////////////////////////////////// /// 
Helper to initialize a block of device data - template - bool initialize_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { - if (dist_kind == cutlass::Distribution::Uniform) { + if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; - if (bits_input == 1) { - scope_max = 2; - scope_min = 0; - } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; - } else if (bits_output == 16) { - scope_max = 5; - scope_min = -5; - } else { - scope_max = 8; - scope_min = -8; - } - - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::AllZeros) { - cutlass::reference::host::TensorFill(view); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - throw std::runtime_error("Not implementated."); + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; } - return true; + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, bits_input); } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} /// Helper to initialize a block of device data (scale_tensors) - template - bool initialize_scale_tensor( - cutlass::TensorView view, - cutlass::Distribution::Kind dist_kind, - uint64_t seed) { +template +bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { - if (dist_kind == cutlass::Distribution::Uniform) { + if (dist_kind == cutlass::Distribution::Uniform) { - double scope_max, scope_min; + double scope_max, scope_min; - scope_min = -1; - scope_max = 1; + scope_min = -1; + scope_max = 1; - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); - } - else if (dist_kind == cutlass::Distribution::AllZeros) { - cutlass::reference::host::TensorFill(view); - } - else if (dist_kind == cutlass::Distribution::Identity) { - - cutlass::reference::host::TensorFillIdentity(view); - } - else if (dist_kind == cutlass::Distribution::Gaussian) { - - cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); - } - 
else if (dist_kind == cutlass::Distribution::Sequential) { - cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); - } - else { - throw std::runtime_error("Not implementated."); - } - - return true; + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min); } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} /// Initialize operands to be used in the GEMM and reference GEMM template @@ -403,11 +403,9 @@ void initialize(const Options &options) { assert(options.n % ScaleGranularityN == 0); // Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape - auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); - auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM; - auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN; - auto blockscale_k = cute::get<2>(blockscale_shape); + auto groupscale_m = ceil_div(options.m, ScaleGranularityM); + auto groupscale_n = ceil_div(options.n, ScaleGranularityN); + auto blockscale_k = ceil_div(options.k, cute::get<2>(TileShape{})); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); @@ -582,13 +580,11 @@ bool verify(const Options &options, const int ScaleMsPerTile const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile; // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape - auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{}))); - auto blockscale_m = cute::get<0>(blockscale_shape); - auto blockscale_n = cute::get<1>(blockscale_shape); - auto blockscale_k = cute::get<2>(blockscale_shape); - auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM; - auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN; + auto blockscale_m = ceil_div(options.m, get<0>(TileShape_{})); + auto blockscale_n = ceil_div(options.n, get<1>(TileShape_{})); + auto blockscale_k = ceil_div(options.k, get<2>(TileShape_{})); + auto groupscale_m = ceil_div(options.m, ScaleGranularityM); + auto groupscale_n = ceil_div(options.n, ScaleGranularityN); // Create instantiation for device reference gemm kernel auto A = cute::make_tensor(tensor_A.host_data(), @@ -676,8 +672,13 @@ bool verify(const Options &options, const int ScaleMsPerTile cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); // compare_reference + bool passed = true; tensor_D.sync_host(); - bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + passed &= 
cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor)); + double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view()); + double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view()); + double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view()); + std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl; #if 0 std::cout << "tensor_ref_D.host_view() {" << std::endl @@ -690,15 +691,21 @@ bool verify(const Options &options, const int ScaleMsPerTile if (IsDFp8 && options.save_amax) { abs_max_D.sync_host(); - passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0)); + std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl; + passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor)); } if (options.save_aux) { tensor_aux.sync_host(); - passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view()); + passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor)); + mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view()); + mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view()); + max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view()); + std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl; if (IsAuxFp8 && options.save_amax) { abs_max_aux.sync_host(); - passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0)); + std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl; + passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor)); } } @@ -716,29 +723,29 @@ int run(Options &options) const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile; bool skip = false; - - if (options.m % ScaleGranularityM != 0) { - std::cout << "Skippig (m size: " << options.m << " less then ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl; - skip = true; - } - - if (options.n % ScaleGranularityN != 0) { - std::cout << "Skippig (n size: " << options.m << " less then ScaleGranularityN: " << ScaleGranularityM << "):" << std::endl; - skip = true; - } - - if (options.k % size<2>(TileShape{}) != 0) { - std::cout << "Skippig (k size: " << options.k << " less then TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl; - skip = true; - } - - if (!skip) std::cout << "Running: " << std::endl; std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; std::cout << " 
ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; - if (skip) return -1; + + if (options.m < ScaleGranularityM) { + std::cout << " Skipping (m size: " << options.m << " less than ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl; + skip = true; + } + + if (options.n < ScaleGranularityN) { + std::cout << " Skipping (n size: " << options.n << " less than ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl; + skip = true; + } + + if (options.k < size<2>(TileShape{})) { + std::cout << " Skipping (k size: " << options.k << " less than TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl; + skip = true; + } + + if (!skip) std::cout << " Running... " << std::endl; + else return -1; initialize(options); @@ -770,17 +777,17 @@ int run(Options &options) std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; } - - if (!result.passed) { - exit(-1); + else { + result.passed = true; } // Run profiling loop if (options.iterations > 0) { GpuTimer timer; - timer.start(); - for (int iter = 0; iter < options.iterations; ++iter) { + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + if (iter == options.warmup) + timer.start(); CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); CUTLASS_CHECK(gemm.run()); } timer.stop(); @@ -806,7 +813,7 @@ int run(Options &options) fflush(stdout); } - return 0; + return result.passed; } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) @@ -852,27 +859,31 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + bool passed = true; std::cout << "Basic split-K GEMM kernel" << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; std::cout << std::endl; std::cout << "StreamK GEMM kernel" << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; - run(options); + passed &= run(options); std::cout << std::endl; + + if (!passed) + return -1; #endif return 0; diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp index 23f05ada..85aff756 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -46,6 +46,8 @@ struct Options { int m = 1024, n = 512, k = 1024, l = 1; RasterOrderOptions raster; int swizzle; + float epsilon = 0.02f; + float non_zero_floor = 1.f; // Parses the command line void parse(int argc, char const **args) { @@ -73,6 +75,8 @@ struct Options { cmd.get_cmd_line_argument("warmup", warmup); cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("verify", verify); + cmd.get_cmd_line_argument("epsilon", epsilon); + cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor); char raster_char; cmd.get_cmd_line_argument("raster", raster_char); @@ -113,7 +117,10 @@ struct Options { << " --save_amax=
Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" << " --swizzle= CTA Rasterization swizzle\n\n" - << " --iterations= Number of profiling iterations to perform.\n\n"; + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --verify= Verify the results.\n\n" + << " --epsilon= The epsilon value for comparing the results.\n\n" + << " --non-zero-floor= The non-zero floor for comparing the results.\n\n"; out << "\n\nExamples:\n\n" diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h index 6bb593bd..0bf90a41 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h @@ -221,9 +221,9 @@ void gett_mainloop( const int N = cute::size<0>(mainloop_params.B.layout()); const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA); const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB); - assert(ScaleGranularityM && M % ScaleGranularityM == 0 + assert(ScaleGranularityM && M % ScaleGranularityM == 0 && "ScaleGranularityM must divide M"); - assert(ScaleGranularityN && N % ScaleGranularityN == 0 + assert(ScaleGranularityN && N % ScaleGranularityN == 0 && "ScaleGranularityN must divide N"); cute::Tensor blockscale_A = domain_offset( diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md index 272d36e5..f4d71ea3 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md @@ -12,3 +12,35 @@ Note that in Example 55, the argument `--g` is used to determine the block scale ## Upcoming features Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed. + +## Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu index 3cee6caf..19d6b89d 100644 --- a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu +++ b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu @@ -194,12 +194,14 @@ struct Options { float alpha, beta; int iterations; int m, n, k; + int swizzle; Options(): help(false), m(8192), n(8192), k(8192), alpha(1.f), beta(0.f), - iterations(10) + iterations(10), + swizzle(0) { } // Parses the command line @@ -217,6 +219,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("swizzle", swizzle); } /// Prints the usage statement. @@ -231,6 +234,7 @@ struct Options { << " --k= Sets the K extent of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" + << " --swizzle= Cluster rasterization swizzle\n\n" << " --iterations= Number of profiling iterations to perform.\n\n"; out @@ -331,6 +335,8 @@ typename Gemm::Arguments args_from_options(const Options &options) {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; + arguments.scheduler.max_swizzle_size = options.swizzle; + return arguments; } diff --git a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu index 69a36310..d476ce00 100644 --- a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu +++ b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu @@ -231,6 +231,7 @@ struct Options { bool save_amax = true; int iterations = 1000; int m = 1024, n = 512, k = 1024, l = 1; + int swizzle = 0; // Parses the command line void parse(int argc, char const **args) { @@ -256,6 +257,7 @@ struct Options { cmd.get_cmd_line_argument("save_aux", save_aux, true); cmd.get_cmd_line_argument("save_amax", save_amax, true); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("swizzle", swizzle); } /// Prints the usage statement. @@ -271,6 +273,7 @@ struct Options { << " --l= Sets the l extent (batch) of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n" + << " --swizzle= Cluster rasterization swizzle\n" << " --scale_a= Scaling factor for A\n" << " --scale_b= Scaling factor for B\n" << " --scale_c= Scaling factor for C\n" @@ -476,6 +479,8 @@ typename Gemm::Arguments args_from_options(const Options &options) fusion_args.amax_D_ptr = abs_max_D.device_data(); } + arguments.scheduler.max_swizzle_size = options.swizzle; + return arguments; } diff --git a/examples/70_blackwell_gemm/CMakeLists.txt b/examples/70_blackwell_gemm/CMakeLists.txt index cb401e3a..0ac1687d 100644 --- a/examples/70_blackwell_gemm/CMakeLists.txt +++ b/examples/70_blackwell_gemm/CMakeLists.txt @@ -28,14 +28,29 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-if (CUTLASS_NVCC_ARCHS MATCHES 100a) +set(TEST_SWIZZLE_1 --swizzle=1) +set(TEST_SWIZZLE_2 --swizzle=2) +set(TEST_SWIZZLE_5 --swizzle=5) +set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384) + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") cutlass_example_add_executable( 70_blackwell_fp16_gemm 70_blackwell_fp16_gemm.cu -) + TEST_COMMAND_OPTIONS + TEST_SWIZZLE_1 + TEST_SWIZZLE_2 + TEST_SWIZZLE_5 + TEST_SWIZZLE_5_UNEVEN +) cutlass_example_add_executable( 70_blackwell_fp8_gemm 70_blackwell_fp8_gemm.cu + TEST_COMMAND_OPTIONS + TEST_SWIZZLE_1 + TEST_SWIZZLE_2 + TEST_SWIZZLE_5 + TEST_SWIZZLE_5_UNEVEN ) endif() diff --git a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu index 427af254..f911262f 100644 --- a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu +++ b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu @@ -74,12 +74,14 @@ struct Options { int m, n, k, l; float alpha, beta; + int swizzle; Options(): help(false), error(false), m(2048), n(2048), k(2048), l(1), - alpha(1.f), beta(0.f) + alpha(1.f), beta(0.f), + swizzle(0) { } // Parses the command line @@ -97,6 +99,7 @@ struct Options { cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("swizzle", swizzle); } /// Prints the usage statement. @@ -112,7 +115,8 @@ struct Options { << " --k= Sets the K extent of the GEMM\n" << " --l= Sets the L extent (batch count) of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n"; + << " --beta= Epilogue scalar beta\n" + << " --swizzle= Cluster rasterization swizzle\n\n"; return out; } @@ -352,6 +356,8 @@ struct ExampleRunner { hw_info }; + arguments.scheduler.max_swizzle_size = options.swizzle; + // See example 48 for details on custom EVT construction if constexpr (UseCustomEVT) { arguments.epilogue.thread = diff --git a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu index f7e12fbf..f729b43d 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu @@ -211,12 +211,14 @@ struct Options { float alpha, beta; int iterations; int m, n, k; + int swizzle = 0; Options(): help(false), m(1024), n(1024), k(1024), alpha(1.f), beta(0.f), - iterations(10) + iterations(10), + swizzle(0) { } // Parses the command line @@ -234,6 +236,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("swizzle", swizzle); } /// Prints the usage statement. 
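The hunks above and below wire a new `--swizzle` option into the persistent tile scheduler of examples 70 through 73. A condensed sketch of that recurring pattern follows; it assumes a CUTLASS 3.x `Gemm` whose scheduler arguments expose `max_swizzle_size` (as in these examples), and the helper name `apply_swizzle_option` is illustrative only, not part of the patch.

```
// Sketch only: the --swizzle wiring this patch repeats across examples 70-73.
// Assumes a CUTLASS 3.x Gemm whose tile-scheduler arguments expose max_swizzle_size.
#include "cutlass/util/command_line.h"

template <class Gemm>
void apply_swizzle_option(typename Gemm::Arguments &arguments, int argc, char const **argv) {
  int swizzle = 0;                                  // 0 leaves the scheduler default in place
  cutlass::CommandLine cmd(argc, argv);
  cmd.get_cmd_line_argument("swizzle", swizzle);    // parse --swizzle=<int>
  arguments.scheduler.max_swizzle_size = swizzle;   // cap the cluster rasterization swizzle
}
```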
@@ -247,7 +250,8 @@ struct Options { << " --n= Sets the N extent of the GEMM\n" << " --k= Sets the K extent of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" + << " --beta= Epilogue scalar beta\n" + << " --swizzle= Cluster rasterization swizzle\n" << " --iterations= Number of profiling iterations to perform.\n\n"; out << "\n\nExamples:\n\n" @@ -333,7 +337,7 @@ bool initialize_block( void initialize(const Options &options) { using namespace cute; // For SFA and SFB tensors layouts - using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); @@ -344,8 +348,8 @@ void initialize(const Options &options) { layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); - layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); - layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); block_A.reset(cutlass::make_Coord(size(layout_A))); block_B.reset(cutlass::make_Coord(size(layout_B))); @@ -387,6 +391,7 @@ typename Gemm::Arguments args_from_options(const Options &options) } }; + arguments.scheduler.max_swizzle_size = options.swizzle; return arguments; } diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu index 2719cab9..75d3437d 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu @@ -177,7 +177,7 @@ using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); using FusionOp = typename Gemm::EpilogueOutputOp; constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; -using SfdOutputCfg = cutlass::detail::Sm100BlockScaledOutputConfig; +using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig; using LayoutSFD = typename SfdOutputCfg::LayoutSF; // @@ -240,12 +240,14 @@ struct Options { float alpha, beta; int iterations; int m, n, k; + int swizzle = 0; Options(): help(false), m(1024), n(1024), k(1024), alpha(1.f), beta(0.f), - iterations(10) + iterations(10), + swizzle(0) { } // Parses the command line @@ -263,6 +265,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("swizzle", swizzle); } /// Prints the usage statement. 
@@ -276,7 +279,8 @@ struct Options { << " --n= Sets the N extent of the GEMM\n" << " --k= Sets the K extent of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" + << " --beta= Epilogue scalar beta\n" + << " --swizzle= Cluster rasterization swizzle\n" << " --iterations= Number of profiling iterations to perform.\n\n"; out << "\n\nExamples:\n\n" @@ -362,9 +366,9 @@ bool initialize_block( void initialize(const Options &options) { using namespace cute; // For SFA and SFB tensors layouts - using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; // For SFD tensor layout - using Sm100BlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + using Sm1xxBlockScaledOutputConfig= typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); @@ -375,8 +379,8 @@ void initialize(const Options &options) { layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); - layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); - layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1)); block_A.reset(cutlass::make_Coord(size(layout_A))); @@ -432,6 +436,7 @@ typename Gemm::Arguments args_from_options(const Options &options) arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data(); } + arguments.scheduler.max_swizzle_size = options.swizzle; return arguments; } diff --git a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu index 2784d050..1d6c1f3c 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu @@ -212,12 +212,14 @@ struct Options { float alpha, beta; int iterations; int m, n, k; + int swizzle = 0; Options(): help(false), m(1024), n(1024), k(1024), alpha(1.f), beta(0.f), - iterations(10) + iterations(10), + swizzle(0) { } // Parses the command line @@ -235,6 +237,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("swizzle", swizzle); } /// Prints the usage statement. 
@@ -248,7 +251,8 @@ struct Options { << " --n= Sets the N extent of the GEMM\n" << " --k= Sets the K extent of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" + << " --beta= Epilogue scalar beta\n" + << " --swizzle= Cluster rasterization swizzle\n" << " --iterations= Number of profiling iterations to perform.\n\n"; out << "\n\nExamples:\n\n" @@ -334,7 +338,7 @@ bool initialize_block( void initialize(const Options &options) { using namespace cute; // For SFA and SFB tensors layouts - using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); @@ -345,8 +349,8 @@ void initialize(const Options &options) { layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); - layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); - layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); block_A.reset(cutlass::make_Coord(size(layout_A))); block_B.reset(cutlass::make_Coord(size(layout_B))); @@ -388,6 +392,7 @@ typename Gemm::Arguments args_from_options(const Options &options) } }; + arguments.scheduler.max_swizzle_size = options.swizzle; return arguments; } diff --git a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu index 19c4efd1..67b82a6e 100644 --- a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu +++ b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu @@ -214,7 +214,8 @@ struct Options { int iterations; int m, n, k; int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n; - + int swizzle = 0; + Options(): help(false), m(4096), n(4096), k(4096), @@ -223,7 +224,8 @@ struct Options { preferred_cluster_m(4), preferred_cluster_n(4), fallback_cluster_m(2), - fallback_cluster_n(1) + fallback_cluster_n(1), + swizzle(0) { } // Parses the command line @@ -245,6 +247,7 @@ struct Options { cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4); cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2); cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1); + cmd.get_cmd_line_argument("swizzle", swizzle); if (!validate_cluster_shape()){ std::cout << "--Invalid cluster shapes" << std::endl; @@ -265,6 +268,7 @@ struct Options { << " --k= Sets the K extent of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n" + << " --swizzle= Cluster rasterization swizzle\n" << " --preferred_cluster_m= Sets the M extent of preferred cluster shape\n" << " --preferred_cluster_n= Sets the N extent of preferred cluster shape\n" << " --fallback_cluster_m= Sets the M extent of fallback cluster shape\n" @@ -384,7 +388,8 
@@ typename Gemm::Arguments args_from_options(const Options &options) { arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1); arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1); - + + arguments.scheduler.max_swizzle_size = options.swizzle; return arguments; } diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu index 1d8db6e2..ad563a4b 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu @@ -242,6 +242,7 @@ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTil struct Options { bool help = false; + bool use_pdl = false; float alpha = FLT_MAX; float beta = FLT_MAX; @@ -264,6 +265,9 @@ struct Options { help = true; return; } + if (cmd.check_cmd_line_flag("use_pdl")) { + use_pdl = true; + } cmd.get_cmd_line_argument("m", m); cmd.get_cmd_line_argument("n", n); @@ -387,7 +391,8 @@ struct Options { << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" << " --iterations= Number of profiling iterations to perform\n\n" << " --benchmark= Executes a benchmark problem size\n" - << " --max_sm_count= Run kernels using only these number of SMs\n"; + << " --max_sm_count= Run kernels using only these number of SMs\n" + << " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n"; out << "\n\nExamples:\n\n" @@ -711,7 +716,7 @@ int run(Options &options, bool host_problem_shapes_available = true) CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); // Correctness / Warmup iteration - CUTLASS_CHECK(gemm.run()); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; @@ -730,7 +735,7 @@ int run(Options &options, bool host_problem_shapes_available = true) timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); - CUTLASS_CHECK(gemm.run()); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); } timer.stop(); diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu index ee697135..d5814c0a 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -219,14 +219,14 @@ using StrideD = typename Gemm::GemmKernel::InternalStrideD; using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; -using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; -using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig< +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< OutputSFVectorSize, cute::is_same_v ? 
cute::UMMA::Major::K : cute::UMMA::Major::MN >; -using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom; -using LayoutSFD = typename Sm100BlockScaledOutputConfig::LayoutSF; +using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; +using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF; // Host-side allocations std::vector stride_A_host; @@ -305,6 +305,7 @@ struct Options { bool help = false; bool verification = true; + bool use_pdl = false; float alpha = FLT_MAX; float beta = FLT_MAX; @@ -328,9 +329,12 @@ struct Options { help = true; return; } - if (cmd.check_cmd_line_flag("no-verif")) { + if (cmd.check_cmd_line_flag("no_verif")) { verification = false; } + if (cmd.check_cmd_line_flag("use_pdl")) { + use_pdl = true; + } cmd.get_cmd_line_argument("m", m); cmd.get_cmd_line_argument("n", n); @@ -457,7 +461,8 @@ struct Options { << " --iterations= Number of profiling iterations to perform\n\n" << " --benchmark= Executes a benchmark problem size\n" << " --max_sm_count= Run kernels using only these number of SMs\n" - << " --no-verif Do not run (host-side) verification kernels\n"; + << " --no_verif Do not run (host-side) verification kernels\n" + << " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n"; out << "\n\nExamples:\n\n" @@ -554,9 +559,9 @@ void allocate(const Options &options) { auto layout_B = make_layout(make_shape(N, K, 1), stride_B); auto layout_C = make_layout(make_shape(M, N, 1), stride_C); auto layout_D = make_layout(make_shape(M, N, 1), stride_D); - auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); - auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); stride_A_host.push_back(stride_A); stride_B_host.push_back(stride_B); @@ -775,9 +780,9 @@ bool verify(const Options &options) { auto layout_B = make_layout(make_shape(N, K, 1), stride_B); auto layout_C = make_layout(make_shape(M, N, 1), stride_C); auto layout_D = make_layout(make_shape(M, N, 1), stride_D); - auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); - auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); // Create the arguments for host reference implementation Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A); @@ -845,7 +850,7 @@ int run(Options &options, bool host_problem_shapes_available = true) CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); // Correctness / Warmup iteration - CUTLASS_CHECK(gemm.run()); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); cudaDeviceSynchronize(); @@ -870,7 +875,7 @@ 
int run(Options &options, bool host_problem_shapes_available = true) timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); - CUTLASS_CHECK(gemm.run()); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); } timer.stop(); diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md index 8766f081..2f4c9c76 100644 --- a/examples/77_blackwell_fmha/README.md +++ b/examples/77_blackwell_fmha/README.md @@ -21,3 +21,35 @@ To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easies The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements. It is well-suited for applying masks or activations. More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA. + +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu b/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu new file mode 100644 index 00000000..058c4b2b --- /dev/null +++ b/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu @@ -0,0 +1,546 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture. + This kernel is optimized for the GeForce RTX 50 series GPUs. + + The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale). + NVFP4 MMA has 2x throughput compared to MXFP8 MMA and 4x throughput compared to Ada Tensor Core FP8 MMA. + (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + This kernel leverages: + 1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. Block Scaled Tensor Core MMA Instructions + 4. Epilogue Optimization + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. 
+ + Usage: + + $ ./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// Kernel Perf config +using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + 
ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", 
beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "79a_blackwell_geforce_nvfp4_bf16_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = 
Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra 
workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu b/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu new file mode 100644 index 00000000..e3ebba4a --- /dev/null +++ b/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu @@ -0,0 +1,593 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM120 architecture. + The kernel outputs quantized fp4 values with scale factors that will be the input of another GEMM. + This kernel is optimized for the GeForce RTX 50 series GPUs. + + Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages: + + 1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. Block Scaled Tensor Core MMA Instructions + 4. Epilogue Optimization + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. 
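+
+    The NVFP4 output path is what distinguishes this example from 79a: alongside D, the epilogue
+    also emits per-block scale factors (SFD) and applies a matrix-wide normalization constant.
+    This is requested through the epilogue fusion declared later in this file:
+
+      using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
+          OutputSFVectorSize, ElementD, ElementCompute, ElementSFD, LayoutSFDTag, ElementC>;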
+ + Usage: + + $ ./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand +using ElementSFD = cutlass::float_ue8m0_t; // Element type for SFD matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand + +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// Kernel Perf config +using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a 
cluster + +constexpr int InputSFVectorSize = 16; +constexpr int OutputSFVectorSize = InputSFVectorSize; + +// D = alpha * acc + beta * C +// With BlockScaleFactor generation. +using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< + OutputSFVectorSize, + ElementD, + ElementCompute, + ElementSFD, LayoutSFDTag, + ElementC>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong // Ping-pong kernel schedule policy. + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
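+
+// Editor's sketch (illustrative, not part of the original example): the comments above note that
+// the scale-factor operands are described by a full cute::Layout rather than a plain stride. One
+// way to see why is to materialize the layouts for a concrete problem size and print them. The
+// helper name below is hypothetical; it only uses aliases already defined in this file.
+inline void debug_print_sf_layouts(int m, int n, int k) {
+  using Cfg = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+  auto layout_sfa = Cfg::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));  // SFA layout for an (m,n,k,1) problem
+  auto layout_sfb = Cfg::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));  // SFB layout for an (m,n,k,1) problem
+  cute::print(layout_sfa); cute::print("\n");  // interleaved, multi-modal layout; not expressible as a single stride
+  cute::print(layout_sfb); cute::print("\n");
+}
+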
+using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +using FusionOp = typename Gemm::EpilogueOutputOp; +constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; +using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig; +using LayoutSFD = typename SfdOutputCfg::LayoutSF; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +LayoutSFD layout_SFD; + +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +cutlass::HostTensor block_SFD; + +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +cutlass::HostTensor block_reference_SFD; +// Matrix-wide normalization constant +cutlass::HostTensor block_Normconst; + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "79b_blackwell_geforce_nvfp4_nvfp4_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + // For SFD tensor layout + using Sm1xxBlockScaledOutputConfig= typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = 
Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1)); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_Normconst.reset(cutlass::make_Coord(1)); + + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_Normconst.at(cutlass::make_Coord(0)) = 2; + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + block_SFD.sync_device(); + block_Normconst.sync_device(); +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + + if constexpr (IsBlockScaleSupported) { + arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data(); + arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data(); + } + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + auto tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D), // TensorD + decltype(tensor_SFD), // TensorSfD + cute::Int, + cutlass::reference::host::SfStrategy::SfDGen + > epilogue_params{options.alpha, 
options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." 
<< std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu b/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu new file mode 100644 index 00000000..ac2f39c9 --- /dev/null +++ b/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu @@ -0,0 +1,546 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates a simple way to instantiate and run a mixed precision blockscaled GEMM on the NVIDIA Blackwell SM120 architecture. + This kernel is optimized for the GeForce RTX 50 series GPUs. + + The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (mma.sync.aligned.block_scale). + MXFP8 MMA has 2x throughput compared to Ada Tensor Core FP8 MMA. + (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages: + 1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper. + 2. 
The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. Block Scaled Tensor Core MMA Instructions + 4. Epilogue Optimization + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. + + Usage: + + $ ./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::mx_float6_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// Kernel Perf config +using ThreadBlockShape = 
Shape<_128,_128,_128>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
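+
+// Editor's note (illustrative, not part of the original example): the C and D alignments above are
+// expressed in elements, but both correspond to 128-bit (16-byte) vectorized accesses. The checks
+// below only restate that relationship using the aliases already defined in this file.
+static_assert(AlignmentC * cutlass::sizeof_bits<ElementC>::value == 128,
+              "AlignmentC corresponds to a 128-bit (16-byte) access on C");
+static_assert(AlignmentD * cutlass::sizeof_bits<ElementD>::value == 128,
+              "AlignmentD corresponds to a 128-bit (16-byte) access on D");
+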
+using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "79c_blackwell_geforce_mixed_mxfp8_bf16_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + + 
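+  // (Editor's note, illustrative) At this point layout_SFA / layout_SFB are fully-specified
+  // interleaved cute::Layouts; size(filter_zeros(...)) below gives the number of scale-factor
+  // elements that actually need backing storage. When debugging allocation sizes it can help to
+  // print them here, e.g.:
+  //   cute::print(layout_SFA); cute::print("\n");
+  //   cute::print(layout_SFB); cute::print("\n");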
block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = 
Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/79_blackwell_geforce_gemm/CMakeLists.txt b/examples/79_blackwell_geforce_gemm/CMakeLists.txt new file mode 100644 index 00000000..cb7e3e97 --- /dev/null +++ b/examples/79_blackwell_geforce_gemm/CMakeLists.txt @@ -0,0 +1,47 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if (CUTLASS_NVCC_ARCHS MATCHES 120a) +cutlass_example_add_executable( + 79a_blackwell_geforce_nvfp4_bf16_gemm + 79a_blackwell_geforce_nvfp4_bf16_gemm.cu +) + +cutlass_example_add_executable( + 79b_blackwell_geforce_nvfp4_nvfp4_gemm + 79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu +) + +cutlass_example_add_executable( + 79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm + 79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu +) + +endif() diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu index 417830f2..3148d2aa 100644 --- a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu @@ -216,7 +216,7 @@ struct Options { out << "\n\nExamples:\n\n" - << "$ " << "81_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + << "$ " << "112_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; return out; } diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a1a5c00a..0f03cd9b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -157,6 +157,7 @@ foreach(EXAMPLE 76_blackwell_conv 77_blackwell_fmha 78_blackwell_emulated_bf16x9_gemm + 79_blackwell_geforce_gemm 81_blackwell_gemm_blockwise ) diff --git a/examples/README.md b/examples/README.md index 68bf7077..92779c07 100644 --- a/examples/README.md +++ b/examples/README.md @@ -282,6 +282,10 @@ Blackwell SM100 FastFP32 (using BF16 to emulate SGEMM) kernel +* [79_blackwell_geforce_gemm](79_blackwell_geforce_gemm/) + + Blackwell SM120 MMA kernel targeting GeForce RTX 50 series CUDA Cores + # CuTe - Programming Examples Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/). @@ -291,3 +295,35 @@ Additionally, CuTe's core layout and layout algebra have their own test cases wi # Python Interface Examples Examples leveraging CUTLASS's [Python interface](../python/README.md) are located in [cutlass/examples/python](python/). + +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/common/gather_tensor.hpp b/examples/common/gather_tensor.hpp index 67ae811b..46fb6400 100644 --- a/examples/common/gather_tensor.hpp +++ b/examples/common/gather_tensor.hpp @@ -58,7 +58,7 @@ struct IndexedGather operator()(I i) const { return indices_[i]; } CUTE_HOST_DEVICE friend - void + void print(IndexedGather const &s) { cute::print("Indexed"); } @@ -80,7 +80,7 @@ struct StridedGather operator()(I i) const { return i * stride_; } CUTE_HOST_DEVICE friend - void + void print(StridedGather const &s) { cute::print("Strided{"); print(s.stride_); @@ -153,7 +153,7 @@ make_custom_stride_layout(Stride const &stride, Func&& func) /// Helper function to optionally create a gather tensor template CUTLASS_HOST_DEVICE -auto +auto make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) { if constexpr (not cutlass::platform::is_same, NoGather>::value) { @@ -180,7 +180,7 @@ upcast(Shape const& shape, Stride const& stride) return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); } else if constexpr (is_scaled_basis::value) { if constexpr (Stride::mode() == I) { - return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); } else { return make_layout(shape, stride); } diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt index f263e5ce..3c9e93c4 100644 --- a/examples/cute/tutorial/CMakeLists.txt +++ b/examples/cute/tutorial/CMakeLists.txt @@ -27,34 +27,31 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+add_subdirectory(hopper) +add_subdirectory(blackwell) cutlass_example_add_executable( - sgemm_1 + cute_tutorial_sgemm_1 sgemm_1.cu ) cutlass_example_add_executable( - sgemm_2 + cute_tutorial_sgemm_2 sgemm_2.cu ) cutlass_example_add_executable( - sgemm_sm70 + cute_tutorial_sgemm_sm70 sgemm_sm70.cu ) cutlass_example_add_executable( - sgemm_sm80 + cute_tutorial_sgemm_sm80 sgemm_sm80.cu ) cutlass_example_add_executable( - tiled_copy + cute_tutorial_tiled_copy tiled_copy.cu ) -cutlass_example_add_executable( - wgmma_sm90 - wgmma_sm90.cu -) - diff --git a/examples/cute/tutorial/blackwell/01_mma_sm100.cu b/examples/cute/tutorial/blackwell/01_mma_sm100.cu new file mode 100644 index 00000000..3f73140a --- /dev/null +++ b/examples/cute/tutorial/blackwell/01_mma_sm100.cu @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// CuTe Tutorial for SM100 Programming +// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used +// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces. +// +// The tutorial series is split into five stages: +// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction. +// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions. +// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA. +// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA. +// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue. 
+// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +// Use Thrust to handle host/device allocations +#include +#include + +// Cutlass includes +#include // F16 data type +#include +#include +#include + +// CuTe includes +#include // CuTe tensor implementation +#include // CuTe functions for querying the details of cluster launched +#include // Compile time in constants such as _1, _256 etc. +#include + +// Tutorial helpers +#include "example_utils.hpp" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tutorial 01: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +// The goal of this tutorial is to show the CuTe interface for tcgen05.mma and tcgen05.ld operations. +// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where: +// - Matrix A is MxK, K-major (BLAS transpose T, row-major) +// - Matrix B is NxK, K-major (BLAS transpose N, column-major) +// - Matrices C and D are MxN, N-major (BLAS row-major) +// +// This GEMM kernel performs the following steps: +// 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) for one MmaTile +// using auto-vectorizing copy operations. +// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +// 4. Read C matrix from global memory (GMEM) to register (RMEM). +// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix. +// 6. Store D matrix from registers (RMEM) to global memory (GMEM). +// +// SM100 tcgen05.mma instructions operate as follows: +// - Read matrix A from SMEM or TMEM +// - Read matrix B from SMEM +// - Write accumulator to TMEM +// The accumulator in TMEM must then be loaded to registers before writing back to GMEM. +// +// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types +// and the MMA's M and N dimensions. +// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors. +// These are the A and B fragments of the tcgen05.mma in CuTe terminology. +// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial. +// +// The MMA details: +// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA +// operation. F32 accumulator type is chosen since both C and D matrices use F32. +// This example uses F16xF16 = F32 MMA where: +// TypeA = cutlass::half_t; // MMA A Data Type +// TypeB = cutlass::half_t; // MMA B Data Type +// TypeC = float; // MMA C Data Type +// TypeD = float; // MMA D Data Type +// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// The shared memory buffers for A and B matrices. +template // (MmaB, NumMma_N, NumMma_K, ...) 
+struct SharedStorage +{ + alignas(128) cute::ArrayEngine> A; + alignas(128) cute::ArrayEngine> B; + + alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } + CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } +}; + +// The device kernel +template +__global__ static +void +gemm_device(ATensor mA, // (Gemm_M, Gemm_K) + BTensor mB, // (Gemm_N, Gemm_K) + CTensor mC, // (Gemm_M, Gemm_N) + DTensor mD, // (Gemm_M, Gemm_N) + MmaTiler_MNK mma_tiler, // + TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K> + ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK) + Alpha alpha, Beta beta) +{ + // Step 1: The Prologue. + + // The CTA layout within the Cluster: (V,M,N,K) -> CTA idx + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename TiledMMA::AtomThrID{})); + + // Construct the MMA grid coordinate from the CTA grid coordinate + auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate + blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate + blockIdx.y, // MMA-N coordinate + _); // MMA-K coordinate + + // Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed + // by this mma tile. + // CuTe provides local_tile partitioning function. local_tile accepts 4 parameters: + // * Tensor to partition + // * Tiler to use for partitioning + // * Coordinate to use for slicing the partitioned tensor + // * Projection to ignore unwanted modes of the Tiler and Coordinate + auto mma_coord = select<1,2,3>(mma_coord_vmnk); + Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K) + Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K) + Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + + if (thread0()) { + print("mA:\t"); print(mA); print("\n"); // mA: gmem_ptr[16b](GMEM_ADDR_A) o (512,256):(256,_1) + print("mB:\t"); print(mB); print("\n"); // mB: gmem_ptr[16b](GMEM_ADDR_B) o (1024,256):(256,_1) + print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1) + print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1) + + print("gA:\t"); print(gA); print("\n"); // gA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile) o (_128,_64,4):(256,_1,_64) + print("gB:\t"); print(gB); print("\n"); // gB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile) o (_256,_64,4):(_1,256,16384) + print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1) + print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1) + } __syncthreads(); + + // The SMEM tensors + + // Allocate SMEM + extern __shared__ char shared_memory[]; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Represent the SMEM buffers for A and B + Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // + // Mma partitioning for A and B + // + // Note: Partitioned tensors use tXgY naming convention: + // tXgY 
-> The partitioning pattern tX applied to tensor gY + + auto mma_v = get<0>(mma_coord_vmnk); + ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate + Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K) + Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N) + Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile + offset_for_mma) o ((_128,_16),_1,_4,4):((256,_1),_0,_16,_64) + print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile + offset_for_mma) o ((_256,_16),_1,_4,4):((_1,256),_0,4096,16384) + print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + } __syncthreads(); + + // MMA Fragment Allocation + // We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations. + // For tcgen05.mma operations: + // - Matrices A and B are sourced from SMEM + // - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively + // - The first mode of each descriptor represents the SMEM for a single MMA operation + Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // TMEM Allocation + // On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM). + // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. + Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) + } __syncthreads(); + + // Barrier Initialization + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + // Barriers in SMEM initialized by a single thread. + if (elect_one_warp && elect_one_thr) { + cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1); + } + int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. + __syncthreads(); // Make sure all threads observe barrier initialization. + + // Step 2: The Mainloop. + + // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. 
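+  // (Illustrative note: ScaleOut::Zero makes the first tcgen05.mma of a tile overwrite the TMEM
+  //  accumulator with A*B, while ScaleOut::One makes later MMAs accumulate A*B on top of it. The
+  //  flag is flipped to One immediately after the first gemm() call in the k_block loop below.)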
+ tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM + for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) + { + // Step 2a: Load A and B tiles + + // Using auto-vectorized copy operation: + // - Utilizes 128 threads for parallel data transfer + // - Copy operations are distributed efficiently across all threads + // - CuTe can automatically determine optimal vector width + cooperative_copy<128>(threadIdx.x, tCgA(_,_,_,k_tile), tCsA); // Load MmaTile_M x MmaTile_K A tile + cooperative_copy<128>(threadIdx.x, tCgB(_,_,_,k_tile), tCsB); // Load MmaTile_N x MmaTile_K B tile + + // Step 2b: Execute the MMAs for this tile + + // Wait for loads to SMEM to complete with __syncthreads() + __syncthreads(); + + // tcgen05.mma instructions require single-thread execution: + // - Only one warp performs the MMA-related loop operations + // - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp + // - No explicit elect_one_sync region is needed from the user + if (elect_one_warp) { + // Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + // Ensure MMAs are completed, only then we can reuse the A and B SMEM. + cutlass::arch::umma_arrive(&shared_storage.mma_barrier); + } + // Wait MMAs to complete to avoid overwriting the A and B SMEM. + cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit); + mma_barrier_phase_bit ^= 1; + } + + // Step 3: The Epilogue. + + // Create the tiled copy operation for the accumulator (TMEM -> RMEM) + TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc); + ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x); + + Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N) + Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N) + // Load C tensor GMEM -> RMEM + copy(tDgC, tDrC); + + Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N) + Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N) + using AccType = typename decltype(tCtAcc)::value_type; + Tensor tDrAcc = make_tensor(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N) + // Load TMEM -> RMEM + copy(tiled_t2r_copy, tDtAcc, tDrAcc); + + // AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC + axpby(alpha, tDrAcc, beta, tDrC); + // Store RMEM -> GMEM + copy(tDrC, tDgD); +} + +template +void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, + TypeB const* device_ptr_B, LayoutB layout_B, + TypeC const* device_ptr_C, LayoutC layout_C, + TypeD * device_ptr_D, LayoutD layout_D, + Alpha const alpha, Beta const beta) +{ + assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M + assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M + assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N + assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N + assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K + + // Represent the full tensors in global memory + Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K) + Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K) + Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N) + Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N) + + // Get 
M, N, K dimensions of the GEMM we are running + auto Gemm_M = shape<0>(layout_A); + auto Gemm_N = shape<0>(layout_B); + auto Gemm_K = shape<1>(layout_A); + std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl; + + //////////////////////////////////////////////////////////// + // + // Initialize the GEMM kernel parameters + // + //////////////////////////////////////////////////////////// + + // Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a + // larger TiledMma from the given mma instruction. + // See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions + TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}); // A and B layouts + + // We can also print and inspect the tiled_mma + print(tiled_mma); + // TiledMMA + // ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0) + // PermutationMNK: (_,_,_) + // MMA_Atom + // ThrID: _1:_0 + // Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size + // LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix + // LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix + // LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix + + // Define MMA tiler sizes (static) + auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMAs per MMA Tile M. + auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMAs per MMA Tile M. + auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16. + auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K) + + // In SM90, the MMAs are CTA-local and perform thread-level partitioning. + // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. + // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA + // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. + // The MMA's partitioning then yeilds the CTA-local work. + + if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { + std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; + return; + } + + if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) { + std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl; + return; + } + + // + // Determine the SMEM layouts: + // + + // * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions. + // * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape. + // These functions take the TiledMma, and the MMA Tile Shape as inputs and returns a shape that is at least rank-3 + // where the first mode has the same shape as the MMA instruction, 2nd and 3rd mode expresses the number of time + // MMA instr is repeated in M/N mode and K mode of MMA tile, respectively. + // * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch. 
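+
+  // (Worked example with this tutorial's concrete numbers: the 128x256x16 MMA instruction tiled
+  //  over the (128, 256, 64) MMA tile gives
+  //     NumMma_M = 128/128 = 1,  NumMma_N = 256/256 = 1,  NumMma_K = 64/16 = 4,
+  //  so mma_shape_A = ((_128,_16),_1,_4) and mma_shape_B = ((_256,_16),_1,_4), matching the
+  //  printed shapes below.)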
+ + // Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K) + auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler))); + // Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K) + auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler))); + + // Print and inspect mma_shape_A, and mma_shape_B for this example. + print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4) + print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4) + + // A and B tensors are swizzled in SMEM to improve MMA performance. + // * However, expressing swizzled layouts is very hard. + // * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes + auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_A); + auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_B); + + // Print and inspect sA_layout and sB_layout for this example. + print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + + // Now we can find the SMEM allocation size + using SMEMStorage = SharedStorage; + + // The cluster shape and layout + auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{}); + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename decltype(tiled_mma)::AtomThrID{})); + + //////////////////////////////////////////////////////////// + // + // Launch GEMM kernel + // + //////////////////////////////////////////////////////////// + + dim3 dimBlock(128); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), + round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + int smemBytes = sizeof(SMEMStorage); + + auto* kernel_ptr = &gemm_device; + + // Set kernel attributes (set SMEM) + CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smemBytes)); + + printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z); + printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z); + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr, + mA, mB, mC, mD, + mma_tiler, tiled_mma, cluster_shape, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if ((props.major != 10) || (props.major == 10 && props.minor > 1)) { + std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 
100a." << std::endl; + std::cerr << " Found " << props.major << "." << props.minor << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + int Gemm_M = 512; + if (argc >= 2) + sscanf(argv[1], "%d", &Gemm_M); + + int Gemm_N = 1024; + if (argc >= 3) + sscanf(argv[2], "%d", &Gemm_N); + + int Gemm_K = 256; + if (argc >= 4) + sscanf(argv[3], "%d", &Gemm_K); + + //////////////////////////////////////////////////////////// + // + // Create A, B, C, and D tensors + // + //////////////////////////////////////////////////////////// + // Define the data types. A and B types are same for MMA instruction. + using TypeA = cutlass::half_t; // MMA A Data Type + auto type_str_a = "half_t"; + using TypeB = cutlass::half_t; // MMA B Data Type + auto type_str_b = "half_t"; + using TypeC = float; // MMA C Data Type + [[maybe_unused]] auto type_str_c = "float"; + using TypeD = float; // MMA D Data Type + auto type_str_d = "float"; + using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type. + + // A tensor MxK K-major (Layout T = Row-Major) + Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1) + // B tensor NxK K-major (Layout N = Column-Major) + Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1) + // C tensor MxN N-major (Layout T = Row-Major) + Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + // D tensor MxN N-major (Layout T = Row-Major) + Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + + // Host allocations and host CuTe tensors for A, B, and C tensors. + thrust::host_vector host_A(Gemm_M * Gemm_K); + Tensor host_tensor_A = make_tensor(host_A.data(), layout_A); + print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1) + + thrust::host_vector host_B(Gemm_N * Gemm_K); + Tensor host_tensor_B = make_tensor(host_B.data(), layout_B); + print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1) + + thrust::host_vector host_C(Gemm_M * Gemm_N); + Tensor host_tensor_C = make_tensor(host_C.data(), layout_C); + print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1) + + // Note that we don't need a host_tensor for D yet. + thrust::device_vector device_D(Gemm_M * Gemm_N); + + // Initialize A, B, and C tensors with random values. 
+ initialize_tensor(host_tensor_A); + initialize_tensor(host_tensor_B); + initialize_tensor(host_tensor_C); + + // Copy A, B, and C tensors from host memory to device memory + thrust::device_vector device_A = host_A; + thrust::device_vector device_B = host_B; + thrust::device_vector device_C = host_C; + + using Alpha = float; + using Beta = float; + Alpha alpha = 1.0f; + Beta beta = 0.0f; + // Setup input and output tensors, and the kernel parameters; and execute the kernel on device + gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A, + device_B.data().get(), layout_B, + device_C.data().get(), layout_C, + device_D.data().get(), layout_D, + alpha, beta); + // Host allocation for D tensor and transfer D tensor from device to host + thrust::host_vector host_D = device_D; + // Create a non-owning CuTe tensor for D tensor + Tensor host_tensor_D = make_tensor(host_D.data(), layout_D); + + //////////////////////////////////////////////////////////// + // + // Execute reference GEMM kernel + // + //////////////////////////////////////////////////////////// + + thrust::host_vector host_reference_D(Gemm_M*Gemm_N); + auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D); + reference_gemm(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta); + + //////////////////////////////////////////////////////////// + // + // Compare results + // + //////////////////////////////////////////////////////////// + + auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A, + type_str_b, host_tensor_B, + type_str_d, host_tensor_D, host_reference_tensor_D); + bool success = relative_error <= 0.0; + std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl; +#else + std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; +#endif + + return 0; +} diff --git a/examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu b/examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu new file mode 100644 index 00000000..e508e552 --- /dev/null +++ b/examples/cute/tutorial/blackwell/02_mma_tma_sm100.cu @@ -0,0 +1,671 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// CuTe Tutorial for SM100 Programming +// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used +// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces. +// +// The tutorial series is split into five stages: +// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction. +// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions. +// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA. +// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA. +// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue. +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +// Use Thrust to handle host/device allocations +#include +#include + +// Cutlass includes +#include // F16 data type +#include +#include +#include + +// CuTe includes +#include // CuTe tensor implementation +#include // CuTe functions for querying the details of cluster launched +#include // Compile time in constants such as _1, _256 etc. +#include + +// Tutorial helpers +#include "example_utils.hpp" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tutorial 02: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions. +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where: +// - Matrix A is MxK, K-major (BLAS transpose T, row-major) +// - Matrix B is NxK, K-major (BLAS transpose N, column-major) +// - Matrices C and D are MxN, N-major (BLAS row-major) +// +// This GEMM kernel extends 01_mma_sm100.cu by adding Tensor Memory Access (TMA) and performs the following steps: +// 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +// 4. Read C matrix from global memory (GMEM) to register (RMEM). +// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix. +// 6. Store D matrix from registers (RMEM) to global memory (GMEM). +// +// SM100 tcgen05.mma instructions operate as follows: +// - Read matrix A from SMEM or TMEM +// - Read matrix B from SMEM +// - Write accumulator to TMEM +// The accumulator in TMEM must then be loaded to registers before writing back to GMEM. 
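+//
+// Compared to 01_mma_sm100.cu, the GMEM -> SMEM loads change from a cooperative all-thread copy to a
+// single-thread asynchronous TMA copy whose completion is tracked by an mbarrier. Side by side
+// (both lines are taken from the two tutorials; they produce the same SMEM contents):
+//   cooperative_copy<128>(threadIdx.x, tCgA(_,_,_,k_tile), tCsA);              // tutorial 01
+//   copy(tma_atom_A.with(shared_storage.tma_barrier), tAgA(_,k_tile), tAsA);   // this tutorial
+// TMA frees the threads from address generation and uses the dedicated copy hardware instead.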
+// +// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types +// and the MMA's M and N dimensions. +// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors. +// These are the A and B fragments of the tcgen05.mma in CuTe terminology. +// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial. +// +// The MMA details: +// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA +// operation. F32 accumulator type is chosen since both C and D matrices use F32. +// This example uses F16xF16 = F32 MMA where: +// TypeA = cutlass::half_t; // MMA A Data Type +// TypeB = cutlass::half_t; // MMA B Data Type +// TypeC = float; // MMA C Data Type +// TypeD = float; // MMA D Data Type +// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// The shared memory buffers for A and B matrices. +template // (MmaB, NumMma_N, NumMma_K, ...) +struct SharedStorage +{ + alignas(128) cute::ArrayEngine> A; + alignas(128) cute::ArrayEngine> B; + + alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } + CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } +}; + +// The device kernel +template +__global__ static +void +gemm_device(ATensor mA, // (Gemm_M, Gemm_K) + BTensor mB, // (Gemm_N, Gemm_K) + CTensor mC, // (Gemm_M, Gemm_N) + DTensor mD, // (Gemm_M, Gemm_N) + MmaTiler_MNK mma_tiler, // + TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K> + ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK) + CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A, + CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B, + Alpha alpha, Beta beta) +{ + // Step 1: The Prologue. + + // The CTA layout within the Cluster: (V,M,N,K) -> CTA idx + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename TiledMMA::AtomThrID{})); + + // Construct the MMA grid coordinate from the CTA grid coordinate + auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate + blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate + blockIdx.y, // MMA-N coordinate + _); // MMA-K coordinate + + // Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed + // by this mma tile. + // CuTe provides local_tile partitioning function. 
local_tile accepts 4 parameters: + // * Tensor to partition + // * Tiler to use for partitioning + // * Coordinate to use for slicing the partitioned tensor + // * Projection to ignore unwanted modes of the Tiler and Coordinate + auto mma_coord = select<1,2,3>(mma_coord_vmnk); + Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K) + Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K) + Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + + if (thread0()) { + print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0) + print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0) + print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1) + print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1) + + print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0) + print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0) + print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1) + print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1) + } __syncthreads(); + + // The SMEM tensors + + // Allocate SMEM + extern __shared__ char shared_memory[]; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Represent the SMEM buffers for A and B + Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // + // Mma partitioning for A and B + // + // Note: Partitioned tensors use tXgY naming convention: + // tXgY -> The partitioning pattern tX applied to tensor gY + + auto mma_v = get<0>(mma_coord_vmnk); + ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate + Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K) + Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N) + Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + } __syncthreads(); + + // MMA Fragment Allocation + // We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations. 
+ // For tcgen05.mma operations: + // - Matrices A and B are sourced from SMEM + // - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively + // - The first mode of each descriptor represents the SMEM for a single MMA operation + Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // TMEM Allocation + // On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM). + // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. + Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) + } __syncthreads(); + + // TMA Setup + // + // These are TMA partitionings, which have a dedicated custom partitioner. + // The Int<0>, Layout<_1> indicates that the TMAs are not multicasted. + // Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host. + // For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK. + // For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK. + // Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy. + // The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info. + + auto [tAgA, tAsA] = tma_partition(tma_atom_A, + Int<0>{}, Layout<_1>{}, + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = tma_partition(tma_atom_B, + Int<0>{}, Layout<_1>{}, + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + // Calculate total bytes that TMA will transfer each tile to track completion + int tma_transaction_bytes = sizeof(make_tensor_like(tAsA)) + + sizeof(make_tensor_like(tBsB)); + + if (thread0()) { + print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0)) + print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0)) + printf("TmaBytes: %d\n", tma_transaction_bytes); + } __syncthreads(); + + // Barrier Initialization + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + // Barriers in SMEM initialized by a single thread. 
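+  // (Illustrative note on the protocol used below: an mbarrier completes a phase once its expected
+  //  arrivals/transaction bytes are all observed. A waiter tracks the phase locally and flips it
+  //  after every completed wait:
+  //    cute::wait_barrier(bar, phase_bit);   // block until the barrier leaves phase `phase_bit`
+  //    phase_bit ^= 1;                       // the next wait targets the new phase
+  //  Here the tma_barrier is completed by TMA transaction bytes and the mma_barrier by umma_arrive.)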
+ if (elect_one_warp && elect_one_thr) { + cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1); + cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1); + } + int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. + int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. + __syncthreads(); // Make sure all threads observe barrier initialization. + + // Step 2: The Mainloop. + + // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM + for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) + { + // Step 2a: Load A and B tiles + + // TMA Load Operations: + // - Execute asynchronous TMA loads with single thread + // - Set transaction bytes and execute with barrier + if (elect_one_warp && elect_one_thr) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); + copy(tma_atom_A.with(shared_storage.tma_barrier), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile + copy(tma_atom_B.with(shared_storage.tma_barrier), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile + } + + // Step 2b: Execute the MMAs for this tile + + // Wait for TMA loads to SMEM to complete + cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit); + tma_barrier_phase_bit ^= 1; + + // tcgen05.mma instructions require single-thread execution: + // - Only one warp performs the MMA-related loop operations + // - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp + // - No explicit elect_one_sync region is needed from the user + if (elect_one_warp) { + // Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + // Ensure MMAs are completed, only then we can reuse the A and B SMEM. + cutlass::arch::umma_arrive(&shared_storage.mma_barrier); + } + // Wait MMAs to complete to avoid overwriting the A and B SMEM. + cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit); + mma_barrier_phase_bit ^= 1; + } + + // Step 3: The Epilogue. 
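+
+  // (Unlike the single-warp MMA issue in the mainloop, the epilogue below is cooperative: make_tmem_copy
+  //  partitions the TMEM accumulator across all 128 threads, each thread pulls its slice to registers
+  //  with tcgen05.ld, computes tDrC = alpha * tDrAcc + beta * tDrC, and stores its slice of D to GMEM.)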
+ + // Create the tiled copy operation for the accumulator (TMEM -> RMEM) + TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc); + ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x); + + Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N) + Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N) + // Load C tensor GMEM -> RMEM + copy(tDgC, tDrC); + + Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N) + Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N) + using AccType = typename decltype(tCtAcc)::value_type; + Tensor tDrAcc = make_tensor(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N) + // Load TMEM -> RMEM + copy(tiled_t2r_copy, tDtAcc, tDrAcc); + + // AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC + axpby(alpha, tDrAcc, beta, tDrC); + // Store RMEM -> GMEM + copy(tDrC, tDgD); +} + +template +void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, + TypeB const* device_ptr_B, LayoutB layout_B, + TypeC const* device_ptr_C, LayoutC layout_C, + TypeD * device_ptr_D, LayoutD layout_D, + Alpha const alpha, Beta const beta) +{ + assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M + assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M + assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N + assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N + assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K + + // Represent the full tensors in global memory + Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K) + Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K) + Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N) + Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N) + + // Get M, N, K dimensions of the GEMM we are running + auto Gemm_M = shape<0>(layout_A); + auto Gemm_N = shape<0>(layout_B); + auto Gemm_K = shape<1>(layout_A); + std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl; + + //////////////////////////////////////////////////////////// + // + // Initialize the GEMM kernel parameters + // + //////////////////////////////////////////////////////////// + + // Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a + // larger TiledMma from the given mma instruction. + // See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions + TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}); // A and B layouts + + // We can also print and inspect the tiled_mma + print(tiled_mma); + // TiledMMA + // ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0) + // PermutationMNK: (_,_,_) + // MMA_Atom + // ThrID: _1:_0 + // Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size + // LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix + // LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix + // LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix + + // Define MMA tiler sizes (static) + auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMAs per MMA Tile M. + auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMAs per MMA Tile M. + auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. 
For 16b types, tcgen05.mma has K16. + auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K) + + // In SM90, the MMAs are CTA-local and perform thread-level partitioning. + // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. + // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA + // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. + // The MMA's partitioning then yeilds the CTA-local work. + + if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { + std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; + return; + } + + if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) { + std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl; + return; + } + + // + // Determine the SMEM layouts: + // + + // * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions. + // * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape. + // These functions take the TiledMma, and the MMA Tile Shape as inputs and returns a shape that is at least rank-3 + // where the first mode has the same shape as the MMA instruction, 2nd and 3rd mode expresses the number of time + // MMA instr is repeated in M/N mode and K mode of MMA tile, respectively. + // * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch. + + // Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K) + auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler))); + // Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K) + auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler))); + + // Print and inspect mma_shape_A, and mma_shape_B for this example. + print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4) + print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4) + + // A and B tensors are swizzled in SMEM to improve MMA performance. + // * However, expressing swizzled layouts is very hard. + // * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes + auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_A); + auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_B); + + // Print and inspect sA_layout and sB_layout for this example. 
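+  // (The Sw<3,4,3> prefix in the printed layouts below is the swizzle functor composed on top of the
+  //  plain SMEM layout. It comes from UMMA::Layout_K_SW128_Atom, the 128-byte swizzle pattern, whose
+  //  purpose is to spread K-major 16-bit rows across SMEM banks so the MMA's SMEM reads avoid bank
+  //  conflicts.)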
+ print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + + // Now we can find the SMEM allocation size + using SMEMStorage = SharedStorage; + + // + // TMA Descriptor Creation (Host Side) + // + + // The cluster shape and layout + auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{}); + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename decltype(tiled_mma)::AtomThrID{})); + + // Create TMA descriptors for A and B matrices + Copy_Atom tma_atom_A = make_tma_atom( + SM90_TMA_LOAD{}, // TMA Load Op + mA, // Source GMEM tensor + sA_layout, // Destination SMEM layout + select<0,2>(mma_tiler) // MK Tiler for TMA operation + ); + Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K) + + print("tma_atom_A:\t"); print(tma_atom_A); print("\n"); + // tma_atom_A: Copy_Atom + // ThrID: _1:_0 + // ValLayoutSrc: (_1,_8192):(_0,_1) + // ValLayoutDst: (_1,_8192):(_0,_1) + // ValLayoutRef: (_1,_8192):(_0,_1) + // ValueType: 16b + + Copy_Atom tma_atom_B = make_tma_atom( + SM90_TMA_LOAD{}, // TMA Load Op + mB, // Source GMEM tensor + sB_layout, // Destination SMEM layout + select<1,2>(mma_tiler) // NK Tiler for TMA operation + ); + Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K) + + print("tma_atom_B:\t"); print(tma_atom_B); print("\n"); + // tma_atom_B: Copy_Atom + // ThrID: _1:_0 + // ValLayoutSrc: (_1,_16384):(_0,_1) + // ValLayoutDst: (_1,_16384):(_0,_1) + // ValLayoutRef: (_1,_16384):(_0,_1) + // ValueType: 16b + + //////////////////////////////////////////////////////////// + // + // Launch GEMM kernel + // + //////////////////////////////////////////////////////////// + + dim3 dimBlock(128); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), + round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + int smemBytes = sizeof(SMEMStorage); + + auto* kernel_ptr = &gemm_device; + + // Set kernel attributes (set SMEM) + CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smemBytes)); + + printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z); + printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z); + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr, + mA_tma, mB_tma, mC, mD, + mma_tiler, tiled_mma, cluster_shape, + tma_atom_A, tma_atom_B, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if ((props.major != 10) || (props.major == 10 && props.minor > 1)) { + std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with 
compute capability 100a." << std::endl; + std::cerr << " Found " << props.major << "." << props.minor << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + int Gemm_M = 512; + if (argc >= 2) + sscanf(argv[1], "%d", &Gemm_M); + + int Gemm_N = 1024; + if (argc >= 3) + sscanf(argv[2], "%d", &Gemm_N); + + int Gemm_K = 256; + if (argc >= 4) + sscanf(argv[3], "%d", &Gemm_K); + + //////////////////////////////////////////////////////////// + // + // Create A, B, C, and D tensors + // + //////////////////////////////////////////////////////////// + // Define the data types. A and B types are same for MMA instruction. + using TypeA = cutlass::half_t; // MMA A Data Type + auto type_str_a = "half_t"; + using TypeB = cutlass::half_t; // MMA B Data Type + auto type_str_b = "half_t"; + using TypeC = float; // MMA C Data Type + [[maybe_unused]] auto type_str_c = "float"; + using TypeD = float; // MMA D Data Type + auto type_str_d = "float"; + using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type. + + // A tensor MxK K-major (Layout T = Row-Major) + Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1) + // B tensor NxK K-major (Layout N = Column-Major) + Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1) + // C tensor MxN N-major (Layout T = Row-Major) + Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + // D tensor MxN N-major (Layout T = Row-Major) + Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + + // Host allocations and host CuTe tensors for A, B, and C tensors. + thrust::host_vector host_A(Gemm_M * Gemm_K); + Tensor host_tensor_A = make_tensor(host_A.data(), layout_A); + print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1) + + thrust::host_vector host_B(Gemm_N * Gemm_K); + Tensor host_tensor_B = make_tensor(host_B.data(), layout_B); + print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1) + + thrust::host_vector host_C(Gemm_M * Gemm_N); + Tensor host_tensor_C = make_tensor(host_C.data(), layout_C); + print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1) + + // Note that we don't need a host_tensor for D yet. + thrust::device_vector device_D(Gemm_M * Gemm_N); + + // Initialize A, B, and C tensors with random values. 
+ initialize_tensor(host_tensor_A); + initialize_tensor(host_tensor_B); + initialize_tensor(host_tensor_C); + + // Copy A, B, and C tensors from host memory to device memory + thrust::device_vector device_A = host_A; + thrust::device_vector device_B = host_B; + thrust::device_vector device_C = host_C; + + using Alpha = float; + using Beta = float; + Alpha alpha = 1.0f; + Beta beta = 0.0f; + // Setup input and output tensors, and the kernel parameters; and execute the kernel on device + gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A, + device_B.data().get(), layout_B, + device_C.data().get(), layout_C, + device_D.data().get(), layout_D, + alpha, beta); + // Host allocation for D tensor and transfer D tensor from device to host + thrust::host_vector host_D = device_D; + // Create a non-owning CuTe tensor for D tensor + Tensor host_tensor_D = make_tensor(host_D.data(), layout_D); + + //////////////////////////////////////////////////////////// + // + // Execute reference GEMM kernel + // + //////////////////////////////////////////////////////////// + + thrust::host_vector host_reference_D(Gemm_M*Gemm_N); + auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D); + reference_gemm(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta); + + //////////////////////////////////////////////////////////// + // + // Compare results + // + //////////////////////////////////////////////////////////// + + auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A, + type_str_b, host_tensor_B, + type_str_d, host_tensor_D, host_reference_tensor_D); + bool success = relative_error <= 0.0; + std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl; +#else + std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; +#endif + + return 0; +} diff --git a/examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu b/examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu new file mode 100644 index 00000000..1c2538e3 --- /dev/null +++ b/examples/cute/tutorial/blackwell/03_mma_tma_multicast_sm100.cu @@ -0,0 +1,711 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// CuTe Tutorial for SM100 Programming +// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used +// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces. +// +// The tutorial series is split into five stages: +// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction. +// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions. +// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA. +// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA. +// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue. +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +// Use Thrust to handle host/device allocations +#include +#include + +// Cutlass includes +#include // F16 data type +#include +#include +#include + +// CuTe includes +#include // CuTe tensor implementation +#include // CuTe functions for querying the details of cluster launched +#include // Compile time in constants such as _1, _256 etc. +#include + +// Tutorial helpers +#include "example_utils.hpp" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tutorial 03: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where: +// - Matrix A is MxK, K-major (BLAS transpose T, row-major) +// - Matrix B is NxK, K-major (BLAS transpose N, column-major) +// - Matrices C and D are MxN, N-major (BLAS row-major) +// +// Key extensions from tutorial 02_mma_tma_sm100.cu: +// 1. Introduce ClusterShape for coordinated execution across thread blocks +// 2. Introduce TMA multicast +// 3. Enhanced TMA <-> MMA synchronization for cluster-wide operations +// +// This GEMM kernel will perform the following steps: +// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA load operations. +// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +// 4. Read C matrix from global memory (GMEM) to register (RMEM). +// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix. +// 6. Store D matrix from registers (RMEM) to global memory (GMEM). 
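+//
+// A multicast TMA load lets CTAs in the same cluster that need the same A (or B) tile receive it with a
+// single transaction that writes each participating CTA's SMEM. As a rough sketch of the change relative
+// to tutorial 02 (names such as mcast_mask_a are illustrative; the exact cluster shape and call sites
+// used by this tutorial appear further below and may differ):
+//   auto cluster_shape = make_shape(Int<2>{}, Int<2>{}, Int<1>{});              // e.g. a 2x2x1 cluster
+//   copy(tma_atom_A.with(tma_barrier, mcast_mask_a), tAgA(_,k_tile), tAsA);     // A multicast across cluster-N
+// where mcast_mask_a is a bitmask selecting the peer CTAs of the cluster that receive the same A tile.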
+// +// SM100 tcgen05.mma instructions operate as follows: +// - Read matrix A from SMEM or TMEM +// - Read matrix B from SMEM +// - Write accumulator to TMEM +// The accumulator in TMEM must then be loaded to registers before writing back to GMEM. +// +// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types +// and the MMA's M and N dimensions. +// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors. +// These are the A and B fragments of the tcgen05.mma in CuTe terminology. +// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial. +// +// The MMA details: +// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA +// operation. F32 accumulator type is chosen since both C and D matrices use F32. +// This example uses F16xF16 = F32 MMA where: +// TypeA = cutlass::half_t; // MMA A Data Type +// TypeB = cutlass::half_t; // MMA B Data Type +// TypeC = float; // MMA C Data Type +// TypeD = float; // MMA D Data Type +// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// The shared memory buffers for A and B matrices. +template // (MmaB, NumMma_N, NumMma_K, ...) +struct SharedStorage +{ + alignas(128) cute::ArrayEngine> A; + alignas(128) cute::ArrayEngine> B; + + alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } + CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } +}; + +// The device kernel +template +__global__ static +void +gemm_device(ATensor mA, // (Gemm_M, Gemm_K) + BTensor mB, // (Gemm_N, Gemm_K) + CTensor mC, // (Gemm_M, Gemm_N) + DTensor mD, // (Gemm_M, Gemm_N) + MmaTiler_MNK mma_tiler, // + TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K> + ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK) + CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A, + CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B, + Alpha alpha, Beta beta) +{ + // Step 1: The Prologue. + + // The CTA layout within the Cluster: (V,M,N,K) -> CTA idx + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename TiledMMA::AtomThrID{})); + + // Construct the MMA grid coordinate from the CTA grid coordinate + auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate + blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate + blockIdx.y, // MMA-N coordinate + _); // MMA-K coordinate + + // Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed + // by this mma tile. + // CuTe provides local_tile partitioning function. 
local_tile accepts 4 parameters: + // * Tensor to partition + // * Tiler to use for partitioning + // * Coordinate to use for slicing the partitioned tensor + // * Projection to ignore unwanted modes of the Tiler and Coordinate + auto mma_coord = select<1,2,3>(mma_coord_vmnk); + Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K) + Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K) + Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + + if (thread0()) { + print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0) + print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0) + print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1) + print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1) + + print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0) + print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0) + print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1) + print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1) + } __syncthreads(); + + // The SMEM tensors + + // Allocate SMEM + extern __shared__ char shared_memory[]; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Represent the SMEM buffers for A and B + Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // + // Mma partitioning for A and B + // + + auto mma_v = get<0>(mma_coord_vmnk); + ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate + Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K) + Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N) + Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + } __syncthreads(); + + // MMA Fragment Allocation + // We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations. 
+ // For tcgen05.mma operations: + // - Matrices A and B are sourced from SMEM + // - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively + // - The first mode of each descriptor represents the SMEM for a single MMA operation + Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // TMEM Allocation + // On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM). + // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. + Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) + } __syncthreads(); + + // TMA Setup + // + // These are TMA partitionings, which have a dedicated custom partitioner. + // In this example, the TMA multicasts the loads across multiple CTAs. + // Loads of A are multicasted along the N dimension of the cluster_shape_MNK and + // Loads of B are multicasted along the M dimension of the cluster_shape_MNK. + // Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host. + // For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK. + // For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK. + // Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy. + // The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info. 
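+  // As a concrete illustration (shapes taken from the prints below): per k-tile, tCgA has shape
+  // ((_128,_16),_1,_4,4); group_modes<0,3> regroups this to (((_128,_16),_1,_4), 4), i.e. one
+  // 128x64 = 8192-element MMA-tile mode plus the Tiles_K mode. tma_partition then reorders that
+  // first mode into the TMA-friendly tAgA/tAsA tensors printed below; the same applies to B with
+  // a 256x64 = 16384-element tile.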
+ + // Each CTA with the same m-coord will load a portion of A + // Each CTA with the same n-coord will load a portion of B + // Multicast behavior for CTA 1,2 in the cluster + // A multicast B multicast + // 0 1 2 3 0 1 2 3 + // 0 - - - - 0 - - X - + // 1 X X X X 1 - - X - + // 2 - - - - 2 - - X - + // 3 - - - - 3 - - X - + // tma_multicast_mask_A = 0x2222 + // tma_multicast_mask_B = 0x0F00 + // mma_multicast_mask_C = 0x2F22 + + // Construct the CTA-in-Cluster coordinate for multicasting + auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster())); + + // Project the cluster_layout for tma_A along the N-modes + auto [tAgA, tAsA] = tma_partition(tma_atom_A, + get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster + make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + // Project the cluster_layout for tma_B along the M-modes + auto [tBgB, tBsB] = tma_partition(tma_atom_B, + get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster + make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + // Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + // Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + // Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C + uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) | + create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + + // Calculate total bytes that TMA will transfer each tile to track completion + int tma_transaction_bytes = sizeof(make_tensor_like(tAsA)) + + sizeof(make_tensor_like(tBsB)); + + if (thread0()) { + print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0)) + print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0)) + printf("tma_transaction_bytes: %d\n", tma_transaction_bytes); + printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a); + printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b); + printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c); + } __syncthreads(); + + // Barrier Initialization + + uint32_t elect_one_thr = cute::elect_one_sync(); + uint32_t elect_one_warp = (threadIdx.x / 32 == 0); + + // Barriers in SMEM initialized by a single thread. 
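+  // Arrival-count arithmetic for this example: the cluster is (4,4,1) with a 1SM MMA atom, so
+  // cluster_layout_vmnk is (1,4,4,1) and num_mcast_participants = 4 + 4 - 1 = 7. This matches the
+  // mma_mcast_mask_c example above (0x2F22 has 7 bits set): MMA completion is signaled to the
+  // CTAs in this CTA's row and column of the cluster, so each mma_barrier expects 7 arrivals.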
+ if (elect_one_warp && elect_one_thr) { + // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) + int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1; + cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants); + cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1); + } + int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. + int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. + cute::cluster_sync(); // Make sure all threads across all CTAs in Cluster observe barrier initialization. + + // Step 2: The Mainloop. + + // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM + for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) + { + // Step 2a: Load A and B tiles + + // TMA Load Operations: + // - Execute asynchronous TMA loads with single thread + // - Set transaction bytes and execute with barrier + if (elect_one_warp && elect_one_thr) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); + copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile + copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile + } + + // Step 2b: Execute the MMAs for this tile + + // Wait for TMA loads to SMEM to complete + cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit); + tma_barrier_phase_bit ^= 1; + + // tcgen05.mma instructions require single-thread execution: + // - Only one warp performs the MMA-related loop operations + // - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp + // - No explicit elect_one_sync region is needed from the user + if (elect_one_warp) { + // Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + // Ensure MMAs are completed, only then we can reuse the A and B SMEM. + cutlass::arch::umma_arrive_multicast(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask. + } + // Wait MMAs to complete to avoid overwriting the A and B SMEM. + cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit); + mma_barrier_phase_bit ^= 1; + } + + // Step 3: The Epilogue. 
+ + // Create the tiled copy operation for the accumulator (TMEM -> RMEM) + TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc); + ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x); + + Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N) + Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N) + // Load C tensor GMEM -> RMEM + copy(tDgC, tDrC); + + Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N) + Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N) + using AccType = typename decltype(tCtAcc)::value_type; + Tensor tDrAcc = make_tensor(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N) + // Load TMEM -> RMEM + copy(tiled_t2r_copy, tDtAcc, tDrAcc); + + // AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC + axpby(alpha, tDrAcc, beta, tDrC); + // Store RMEM -> GMEM + copy(tDrC, tDgD); +} + +template +void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, + TypeB const* device_ptr_B, LayoutB layout_B, + TypeC const* device_ptr_C, LayoutC layout_C, + TypeD * device_ptr_D, LayoutD layout_D, + Alpha const alpha, Beta const beta) +{ + assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M + assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M + assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N + assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N + assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K + + // Represent the full tensors in global memory + Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K) + Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K) + Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N) + Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N) + + // Get M, N, K dimensions of the GEMM we are running + auto Gemm_M = shape<0>(layout_A); + auto Gemm_N = shape<0>(layout_B); + auto Gemm_K = shape<1>(layout_A); + std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl; + + //////////////////////////////////////////////////////////// + // + // Initialize the GEMM kernel parameters + // + //////////////////////////////////////////////////////////// + + // Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a + // larger TiledMma from the given mma instruction. + // See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions + TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}); // A and B layouts + + // We can also print and inspect the tiled_mma + print(tiled_mma); + // TiledMMA + // ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0) + // PermutationMNK: (_,_,_) + // MMA_Atom + // ThrID: _1:_0 + // Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size + // LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix + // LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix + // LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix + + // Define MMA tiler sizes (static) + auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMAs per MMA Tile M. + auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMAs per MMA Tile M. + auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. 
For 16b types, tcgen05.mma has K16. + auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K) + + // In SM90, the MMAs are CTA-local and perform thread-level partitioning. + // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. + // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA + // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. + // The MMA's partitioning then yeilds the CTA-local work. + + if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { + std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; + return; + } + + if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) { + std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl; + return; + } + + // + // Determine the SMEM layouts: + // + + // * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions. + // * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape. + // These functions take the TiledMma, and the MMA Tile Shape as inputs and returns a shape that is at least rank-3 + // where the first mode has the same shape as the MMA instruction, 2nd and 3rd mode expresses the number of time + // MMA instr is repeated in M/N mode and K mode of MMA tile, respectively. + // * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch. + + // Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K) + auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler))); + // Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K) + auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler))); + + // Print and inspect mma_shape_A, and mma_shape_B for this example. + print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4) + print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4) + + // A and B tensors are swizzled in SMEM to improve MMA performance. + // * However, expressing swizzled layouts is very hard. + // * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes + auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_A); + auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_B); + + // Print and inspect sA_layout and sB_layout for this example. 
+ print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + + // Now we can find the SMEM allocation size + using SMEMStorage = SharedStorage; + + // + // TMA Descriptor Creation (Host Side) + // + + // The cluster shape and layout + auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{}); + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename decltype(tiled_mma)::AtomThrID{})); + + Copy_Atom tma_atom_A = make_tma_atom( + SM90_TMA_LOAD_MULTICAST{}, // TMA load operation with multicast + mA, // Source GMEM tensor + sA_layout, // Destination SMEM layout + select<0,2>(mma_tiler), // MK Tiler for TMA operation + size<2>(cluster_layout_vmnk) // The number of CTAs in the N-mode for multicasting + ); + Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K) + + print("tma_atom_A:\t"); print(tma_atom_A); print("\n"); + // tma_atom_A: Copy_Atom + // ThrID: _1:_0 + // ValLayoutSrc: (_1,_8192):(_0,_1) + // ValLayoutDst: (_1,_8192):(_0,_1) + // ValLayoutRef: (_1,_8192):(_0,_1) + // ValueType: 16b + + Copy_Atom tma_atom_B = make_tma_atom( + SM90_TMA_LOAD_MULTICAST{}, // TMA load operation with multicast + mB, // Source GMEM tensor + sB_layout, // Destination SMEM layout + select<1,2>(mma_tiler), // NK Tiler for TMA operation + size<1>(cluster_layout_vmnk) // The number of CTAs in the M-mode for multicasting + ); + Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K) + + print("tma_atom_B:\t"); print(tma_atom_B); print("\n"); + // tma_atom_B: Copy_Atom + // ThrID: _1:_0 + // ValLayoutSrc: (_1,_16384):(_0,_1) + // ValLayoutDst: (_1,_16384):(_0,_1) + // ValLayoutRef: (_1,_16384):(_0,_1) + // ValueType: 16b + + //////////////////////////////////////////////////////////// + // + // Launch GEMM kernel + // + //////////////////////////////////////////////////////////// + + dim3 dimBlock(128); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), + round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + int smemBytes = sizeof(SMEMStorage); + + auto* kernel_ptr = &gemm_device; + + // Set kernel attributes (set SMEM) + CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smemBytes)); + + printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z); + printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z); + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr, + mA_tma, mB_tma, mC, mD, + mma_tiler, tiled_mma, cluster_shape, + tma_atom_A, tma_atom_B, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << 
std::endl; + return -1; + } + + if ((props.major != 10) || (props.major == 10 && props.minor > 1)) { + std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl; + std::cerr << " Found " << props.major << "." << props.minor << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + int Gemm_M = 512; + if (argc >= 2) + sscanf(argv[1], "%d", &Gemm_M); + + int Gemm_N = 1024; + if (argc >= 3) + sscanf(argv[2], "%d", &Gemm_N); + + int Gemm_K = 256; + if (argc >= 4) + sscanf(argv[3], "%d", &Gemm_K); + + //////////////////////////////////////////////////////////// + // + // Create A, B, C, and D tensors + // + //////////////////////////////////////////////////////////// + // Define the data types. A and B types are same for MMA instruction. + using TypeA = cutlass::half_t; // MMA A Data Type + auto type_str_a = "half_t"; + using TypeB = cutlass::half_t; // MMA B Data Type + auto type_str_b = "half_t"; + using TypeC = float; // MMA C Data Type + [[maybe_unused]] auto type_str_c = "float"; + using TypeD = float; // MMA D Data Type + auto type_str_d = "float"; + using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type. + + // A tensor MxK K-major (Layout T = Row-Major) + Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1) + // B tensor NxK K-major (Layout N = Column-Major) + Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1) + // C tensor MxN N-major (Layout T = Row-Major) + Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + // D tensor MxN N-major (Layout T = Row-Major) + Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + + // Host allocations and host CuTe tensors for A, B, and C tensors. + thrust::host_vector host_A(Gemm_M * Gemm_K); + Tensor host_tensor_A = make_tensor(host_A.data(), layout_A); + print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1) + + thrust::host_vector host_B(Gemm_N * Gemm_K); + Tensor host_tensor_B = make_tensor(host_B.data(), layout_B); + print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1) + + thrust::host_vector host_C(Gemm_M * Gemm_N); + Tensor host_tensor_C = make_tensor(host_C.data(), layout_C); + print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1) + + // Note that we don't need a host_tensor for D yet. + thrust::device_vector device_D(Gemm_M * Gemm_N); + + // Initialize A, B, and C tensors with random values. 
+ initialize_tensor(host_tensor_A); + initialize_tensor(host_tensor_B); + initialize_tensor(host_tensor_C); + + // Copy A, B, and C tensors from host memory to device memory + thrust::device_vector device_A = host_A; + thrust::device_vector device_B = host_B; + thrust::device_vector device_C = host_C; + + using Alpha = float; + using Beta = float; + Alpha alpha = 1.0f; + Beta beta = 0.0f; + // Setup input and output tensors, and the kernel parameters; and execute the kernel on device + gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A, + device_B.data().get(), layout_B, + device_C.data().get(), layout_C, + device_D.data().get(), layout_D, + alpha, beta); + // Host allocation for D tensor and transfer D tensor from device to host + thrust::host_vector host_D = device_D; + // Create a non-owning CuTe tensor for D tensor + Tensor host_tensor_D = make_tensor(host_D.data(), layout_D); + + //////////////////////////////////////////////////////////// + // + // Execute reference GEMM kernel + // + //////////////////////////////////////////////////////////// + + thrust::host_vector host_reference_D(Gemm_M*Gemm_N); + auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D); + reference_gemm(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta); + + //////////////////////////////////////////////////////////// + // + // Compare results + // + //////////////////////////////////////////////////////////// + + auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A, + type_str_b, host_tensor_B, + type_str_d, host_tensor_D, host_reference_tensor_D); + bool success = relative_error <= 0.0; + std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl; +#else + std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; +#endif + + return 0; +} diff --git a/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu b/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu new file mode 100644 index 00000000..290436ea --- /dev/null +++ b/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu @@ -0,0 +1,716 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// CuTe Tutorial for SM100 Programming +// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used +// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces. +// +// The tutorial series is split into five stages: +// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction. +// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions. +// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA. +// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA. +// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue. +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +// Use Thrust to handle host/device allocations +#include +#include + +// Cutlass includes +#include // F16 data type +#include +#include +#include + +// CuTe includes +#include // CuTe tensor implementation +#include // CuTe functions for querying the details of cluster launched +#include // Compile time in constants such as _1, _256 etc. +#include + +// Tutorial helpers +#include "example_utils.hpp" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tutorial 04: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where: +// - Matrix A is MxK, K-major (BLAS transpose T, row-major) +// - Matrix B is NxK, K-major (BLAS transpose N, column-major) +// - Matrices C and D are MxN, N-major (BLAS row-major) +// +// Key extensions to tutorial 03_mma_tma_multicast_sm100.cu: +// 1. Introduce 2SM tcgen05.mma instructions +// 2. Introduce 2SM TMA instructions +// 3. Demonstrate TMA multicast pattern specialized for 2SM instructions for loading A and B matrices +// +// This GEMM kernel will perform the following steps: +// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA.2SM load operations. +// 2. Perform matrix multiply-accumulate (MMA) operations using 2SM tcgen05.mma instruction. +// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +// 4. Read C matrix from global memory (GMEM) to register (RMEM). +// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix. +// 6. Store D matrix from registers (RMEM) to global memory (GMEM). 
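+//
+// Worked sizing example for the default problem size used in this file (M=512, N=1024, K=256):
+//   - The 2SM tcgen05.mma atom is 256x256x16, so the MMA tile is 256x256x64 and is produced by a
+//     pair of peer CTAs; each peer stages a 128x64 slice of A and a 128x64 slice of B in its own
+//     SMEM (2 x 8192 f16 values per TMA atom per tile, as shown by the tma_atom prints below).
+//   - With the (4,4,1) cluster, the 4x4 grid of CTAs hosts 2x4 = 8 two-CTA MMAs covering the
+//     512x1024 output.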
+// +// SM100 2SM tcgen05.mma instructions operate as follows: +// - Mma is launched by only one SM +// With 2SM MMA instructions, only 1 of the 2 CTAs collaborating on MMA executes the instruction. +// We call the collaborating CTAs, peer CTAs. And the CTA executing the MMA instruction is called leader CTA. +// - Read matrix A from SMEM or TMEM +// - Read matrix B from SMEM +// - Write accumulator to TMEM +// The accumulator in TMEM must then be loaded to registers before writing back to GMEM. +// +// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types +// and the MMA's M and N dimensions. +// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors. +// These are the A and B fragments of the tcgen05.mma in CuTe terminology. +// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial. +// +// The MMA details: +// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 256x256x16 MMA +// operation. F32 accumulator type is chosen since both C and D matrices use F32. +// This example uses F16xF16 = F32 MMA where: +// TypeA = cutlass::half_t; // MMA A Data Type +// TypeB = cutlass::half_t; // MMA B Data Type +// TypeC = float; // MMA C Data Type +// TypeD = float; // MMA D Data Type +// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// The shared memory buffers for A and B matrices. +template // (MmaB, NumMma_N, NumMma_K, ...) +struct SharedStorage +{ + alignas(128) cute::ArrayEngine> A; + alignas(128) cute::ArrayEngine> B; + + alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } + CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } +}; + +// The device kernel +template +__global__ static +void +gemm_device(ATensor mA, // (Gemm_M, Gemm_K) + BTensor mB, // (Gemm_N, Gemm_K) + CTensor mC, // (Gemm_M, Gemm_N) + DTensor mD, // (Gemm_M, Gemm_N) + MmaTiler_MNK mma_tiler, // + TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K> + ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK) + CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A, + CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B, + Alpha alpha, Beta beta) +{ + // Step 1: The Prologue. + + // The CTA layout within the Cluster: (V,M,N,K) -> CTA idx + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename TiledMMA::AtomThrID{})); + + // Construct the MMA grid coordinate from the CTA grid coordinate + auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate + blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate + blockIdx.y, // MMA-N coordinate + _); // MMA-K coordinate + + // Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed + // by this mma tile. + // CuTe provides local_tile partitioning function. 
local_tile accepts 4 parameters: + // * Tensor to partition + // * Tiler to use for partitioning + // * Coordinate to use for slicing the partitioned tensor + // * Projection to ignore unwanted modes of the Tiler and Coordinate + auto mma_coord = select<1,2,3>(mma_coord_vmnk); + Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K) + Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K) + Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + + if (thread0()) { + print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0) + print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0) + print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1) + print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1) + + print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0) + print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0) + print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1) + print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1) + } __syncthreads(); + + // The SMEM tensors + + // Allocate SMEM + extern __shared__ char shared_memory[]; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Represent the SMEM buffers for A and B + Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // + // Mma partitioning for A and B + // + + auto mma_v = get<0>(mma_coord_vmnk); + ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate + Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K) + Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N) + Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + } __syncthreads(); + + // MMA Fragment Allocation + // We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations. 
+ // For tcgen05.mma operations: + // - Matrices A and B are sourced from SMEM + // - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively + // - The first mode of each descriptor represents the SMEM for a single MMA operation + Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // TMEM Allocation + // On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM). + // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. + Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) + } __syncthreads(); + + // TMA Setup + // + // These are TMA partitionings, which have a dedicated custom partitioner. + // In this example, the TMA multicasts the loads across multiple CTAs. + // Loads of A are multicasted along the N dimension of the cluster_shape_VMNK and + // Loads of B are multicasted along the M dimension of the cluster_shape_VMNK. + // Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host. + // For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK. + // For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK. + // Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy. + // The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info. 
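+  // Note on the 2SM TMA path: the per-tile transaction bytes count the SMEM of both peer CTAs
+  // (hence the size<0>(cluster_layout_vmnk) = 2 factor below, and the (_2,_8192) ValLayout in the
+  // tma_atom prints), and only the leader CTA sets and waits on the TMA barrier in the mainloop
+  // even though both peer CTAs issue their copies.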
+ + // Each CTA with the same m-coord will load a portion of A + // Each CTA with the same n-coord will load a portion of B + // Computation of the multicast masks must take into account the Peer CTA for TMA.2SM + + // Construct the CTA-in-Cluster coordinate for multicasting + auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster())); + + // Project the cluster_layout for tma_A along the N-modes + auto [tAgA, tAsA] = tma_partition(tma_atom_A, + get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster + make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + // Project the cluster_layout for tma_B along the M-modes + auto [tBgB, tBsB] = tma_partition(tma_atom_B, + get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster + make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + // Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + // Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + // Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C + uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) | + create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + + // Calculate total bytes that TMA will transfer each tile to track completion, accounting for TMA.2SM + int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA)) + + size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB)); + + if (thread0()) { + print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0)) + print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0)) + printf("tma_transaction_bytes: %d\n", tma_transaction_bytes); + printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a); + printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b); + printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c); + } __syncthreads(); + + // Barrier Initialization + auto elect_one_thr = cute::elect_one_sync(); + auto elect_one_warp = (threadIdx.x / 32 == 0); + auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; + + // Barriers in SMEM should be initialized by a single thread. + if (elect_one_warp && elect_one_thr) { + // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) + int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1; + cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants); + cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1); + } + int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. 
+ int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. + cute::cluster_sync(); // Make sure all CTAs in Cluster observe barrier init and TMEM alloc. + + // Step 2: The Mainloop. + + // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM + for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) + { + // Step 2a: Load A and B tiles + + // TMA Load Operations: + // - Execute asynchronous TMA loads with single thread + // - Both peer and leader CTAs initiate TMA loads + // - Set expected transaction bytes. For 2SM TMA instructions, the transaction bytes counts both CTAs. + // - Although TMAs are initiated by both peer and leader CTAs, the barrier is only set and waited by the leader CTA. + // - Initiate asynchronous transfers with a multicast mask that includes all CTAs that participate in multicast. + if (elect_one_warp && elect_one_thr) { // TMA loads are executed by one thread + if (elect_one_cta) { // Only the leader CTA waits for TMA transactions + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); // Set the expected transaction bytes for the TMA loads + } + copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile + copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile + } + + // Step 2b: Execute the MMAs for this tile + + if (elect_one_cta) { + // Wait for TMA loads to complete on leader CTAs + cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit); + tma_barrier_phase_bit ^= 1; + + // tcgen05.mma instructions require single-thread execution: + // - Only one warp performs the MMA-related loop operations + // - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp + // - No explicit elect_one_sync region is needed from the user + if (elect_one_warp) { + // Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + // Ensure MMAs are completed, only then we can reuse the A and B SMEM. + cutlass::arch::umma_arrive_multicast_2x1SM(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask. + } + } + // Wait MMAs to complete to avoid overwriting the A and B SMEM. + cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit); + mma_barrier_phase_bit ^= 1; + } + + // Step 3: The Epilogue. 
+ + // Create the tiled copy operation for the accumulator (TMEM -> RMEM) + TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc); + ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x); + + Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N) + Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N) + // Load C tensor GMEM -> RMEM + copy(tDgC, tDrC); + + Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N) + Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N) + using AccType = typename decltype(tCtAcc)::value_type; + Tensor tDrAcc = make_tensor(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N) + // Load TMEM -> RMEM + copy(tiled_t2r_copy, tDtAcc, tDrAcc); + + // AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC + axpby(alpha, tDrAcc, beta, tDrC); + // Store RMEM -> GMEM + copy(tDrC, tDgD); +} + +template +void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, + TypeB const* device_ptr_B, LayoutB layout_B, + TypeC const* device_ptr_C, LayoutC layout_C, + TypeD * device_ptr_D, LayoutD layout_D, + Alpha const alpha, Beta const beta) +{ + assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M + assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M + assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N + assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N + assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K + + // Represent the full tensors in global memory + Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K) + Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K) + Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N) + Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N) + + // Get M, N, K dimensions of the GEMM we are running + auto Gemm_M = shape<0>(layout_A); + auto Gemm_N = shape<0>(layout_B); + auto Gemm_K = shape<1>(layout_A); + std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl; + + //////////////////////////////////////////////////////////// + // + // Initialize the GEMM kernel parameters + // + //////////////////////////////////////////////////////////// + + // Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a + // larger TiledMma from the given mma instruction. + // See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions + TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_2x1SM_SS{}); // A and B layouts + + // We can also print and inspect the tiled_mma + print(tiled_mma); + // TiledMMA + // ThrLayoutVMNK: (_2,_1,_1,_1):(_1,_0,_0,_0) + // PermutationMNK: (_,_,_) + // MMA_Atom + // ThrID: _2:_1 + // Shape_MNK: (_256,_256,_16) // MmaM, MmaN, MmaK (MmaK is constant for each instr.) + // LayoutA_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for A matrix + // LayoutB_TV: (_2,(_128,_16)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix + // LayoutC_TV: (_2,(_128,_256)):(_128,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix + + // Define MMA tiler sizes (static) + auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMAs per MMA Tile M. + auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMAs per MMA Tile M. + auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. 
We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16. + auto mma_tiler = make_shape(bM, bN, bK); // (MMA_M, MMA_N, MMA_K) + + // In SM90, the MMAs are CTA-local and perform thread-level partitioning. + // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning. + // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA + // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA. + // The MMA's partitioning then yeilds the CTA-local work. + + if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) { + std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl; + return; + } + + if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) { + std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl; + return; + } + + // + // Determine the SMEM layouts: + // + + // * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions. + // * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape. + // These functions take the TiledMma, and the MMA Tile Shape as inputs and returns a shape that is at least rank-3 + // where the first mode has the same shape as the MMA instruction, 2nd and 3rd mode expresses the number of time + // MMA instr is repeated in M/N mode and K mode of MMA tile, respectively. + // * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch. + + // Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K) + auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler))); + // Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K) + auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler))); + + // Print and inspect mma_shape_A, and mma_shape_B for this example. + print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4) + print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4) + + // A and B tensors are swizzled in SMEM to improve MMA performance. + // * However, expressing swizzled layouts is very hard. + // * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes + auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_A); + auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_B); + + // Print and inspect sA_layout and sB_layout for this example. + print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + + // Now we can find the SMEM allocation size + using SMEMStorage = SharedStorage; + + // + // TMA Descriptor Creation (Host Side) + // + + // The cluster shape and layout + auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{}); + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename decltype(tiled_mma)::AtomThrID{})); + + // SM100 interface for creating TMA loads. + Copy_Atom tma_atom_A = make_tma_atom_A_sm100( + SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction. 
+ mA, // Source GMEM tensor + sA_layout, // Destination SMEM layout + mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes. + tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning. + cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed. + // We have make_tma_atom_[A|B]_sm100 and which determines the multicast mode. + Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K) + + print("tma_atom_A:\t"); print(tma_atom_A); print("\n"); + // tma_atom_A: Copy_Atom + // ThrID: _2:_1 + // ValLayoutSrc: (_2,_8192):(_8192,_1) + // ValLayoutDst: (_2,_8192):(_8192,_1) + // ValLayoutRef: (_2,_8192):(_8192,_1) + // ValueType: 16b + + // SM100 interface for creating TMA loads. + Copy_Atom tma_atom_B = make_tma_atom_B_sm100( + SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction. + mB, // Source GMEM tensor + sB_layout, // Destination SMEM layout + mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes. + tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning. + cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed. + // We have make_tma_atom_[A|B]_sm100 and which determines the multicast mode. + Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K) + + print("tma_atom_B:\t"); print(tma_atom_B); print("\n"); + // tma_atom_B: Copy_Atom + // ThrID: _2:_1 + // ValLayoutSrc: (_2,_8192):(_8192,_1) + // ValLayoutDst: (_2,_8192):(_8192,_1) + // ValLayoutRef: (_2,_8192):(_8192,_1) + // ValueType: 16b + + //////////////////////////////////////////////////////////// + // + // Launch GEMM kernel + // + //////////////////////////////////////////////////////////// + + dim3 dimBlock(128); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), + round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + int smemBytes = sizeof(SMEMStorage); + + auto* kernel_ptr = &gemm_device; + + // Set kernel attributes (set SMEM) + CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smemBytes)); + + printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z); + printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z); + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr, + mA_tma, mB_tma, mC, mD, + mma_tiler, tiled_mma, cluster_shape, + tma_atom_A, tma_atom_B, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if ((props.major != 10) || (props.major == 10 && props.minor > 1)) { + std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." 
<< std::endl; + std::cerr << " Found " << props.major << "." << props.minor << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + int Gemm_M = 512; + if (argc >= 2) + sscanf(argv[1], "%d", &Gemm_M); + + int Gemm_N = 1024; + if (argc >= 3) + sscanf(argv[2], "%d", &Gemm_N); + + int Gemm_K = 256; + if (argc >= 4) + sscanf(argv[3], "%d", &Gemm_K); + + //////////////////////////////////////////////////////////// + // + // Create A, B, C, and D tensors + // + //////////////////////////////////////////////////////////// + // Define the data types. A and B types are same for MMA instruction. + using TypeA = cutlass::half_t; // MMA A Data Type + auto type_str_a = "half_t"; + using TypeB = cutlass::half_t; // MMA B Data Type + auto type_str_b = "half_t"; + using TypeC = float; // MMA C Data Type + [[maybe_unused]] auto type_str_c = "float"; + using TypeD = float; // MMA D Data Type + auto type_str_d = "float"; + using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type. + + // A tensor MxK K-major (Layout T = Row-Major) + Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1) + // B tensor NxK K-major (Layout N = Column-Major) + Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1) + // C tensor MxN N-major (Layout T = Row-Major) + Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + // D tensor MxN N-major (Layout T = Row-Major) + Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + + // Host allocations and host CuTe tensors for A, B, and C tensors. + thrust::host_vector host_A(Gemm_M * Gemm_K); + Tensor host_tensor_A = make_tensor(host_A.data(), layout_A); + print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1) + + thrust::host_vector host_B(Gemm_N * Gemm_K); + Tensor host_tensor_B = make_tensor(host_B.data(), layout_B); + print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1) + + thrust::host_vector host_C(Gemm_M * Gemm_N); + Tensor host_tensor_C = make_tensor(host_C.data(), layout_C); + print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1) + + // Note that we don't need a host_tensor for D yet. + thrust::device_vector device_D(Gemm_M * Gemm_N); + + // Initialize A, B, and C tensors with random values. 
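  // A note on the initialization strategy (an observation added here, not a comment from the original example):
  // initialize_tensor() in example_utils.hpp fills the tensors with small whole-number values in [-2, 2].
  // Those integers, their products, and their K-length sums are all exactly representable in half_t / float,
  // so the device kernel and the host reference_gemm() should agree bit-for-bit. This appears to be why the
  // final check below can demand an exact match (relative_error <= 0.0) instead of a floating-point tolerance.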
+ initialize_tensor(host_tensor_A); + initialize_tensor(host_tensor_B); + initialize_tensor(host_tensor_C); + + // Copy A, B, and C tensors from host memory to device memory + thrust::device_vector device_A = host_A; + thrust::device_vector device_B = host_B; + thrust::device_vector device_C = host_C; + + using Alpha = float; + using Beta = float; + Alpha alpha = 1.0f; + Beta beta = 0.0f; + // Setup input and output tensors, and the kernel parameters; and execute the kernel on device + gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A, + device_B.data().get(), layout_B, + device_C.data().get(), layout_C, + device_D.data().get(), layout_D, + alpha, beta); + // Host allocation for D tensor and transfer D tensor from device to host + thrust::host_vector host_D = device_D; + // Create a non-owning CuTe tensor for D tensor + Tensor host_tensor_D = make_tensor(host_D.data(), layout_D); + + //////////////////////////////////////////////////////////// + // + // Execute reference GEMM kernel + // + //////////////////////////////////////////////////////////// + + thrust::host_vector host_reference_D(Gemm_M*Gemm_N); + auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D); + reference_gemm(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta); + + //////////////////////////////////////////////////////////// + // + // Compare results + // + //////////////////////////////////////////////////////////// + + auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A, + type_str_b, host_tensor_B, + type_str_d, host_tensor_D, host_reference_tensor_D); + bool success = relative_error <= 0.0; + std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl; +#else + std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; +#endif + + return 0; +} diff --git a/examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu b/examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu new file mode 100644 index 00000000..6d9ab03f --- /dev/null +++ b/examples/cute/tutorial/blackwell/05_mma_tma_epi_sm100.cu @@ -0,0 +1,825 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// CuTe Tutorial for SM100 Programming +// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used +// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces. +// +// The tutorial series is split into five stages: +// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction. +// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions. +// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA. +// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA. +// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue. +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +// Use Thrust to handle host/device allocations +#include +#include + +// Cutlass includes +#include // F16 data type +#include +#include +#include + +// CuTe includes +#include // CuTe tensor implementation +#include // CuTe functions for querying the details of cluster launched +#include // Compile time in constants such as _1, _256 etc. +#include + +// Tutorial helpers +#include "example_utils.hpp" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tutorial 05: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// We will implement a GEMM operation: D (f32) = beta * C (F32) + alpha * A (F16) * B (F16) where: +// - Matrix A is MxK, K-major (BLAS transpose T, row-major) +// - Matrix B is NxK, K-major (BLAS transpose N, column-major) +// - Matrices C and D are MxN, N-major (BLAS row-major) +// +// Key extensions to tutorial 04_mma_tma_2sm_sm100.cu: +// 1. Demonstrate using TMA instructions in the epilogue +// +// This GEMM kernel will perform the following steps: +// 1. Load A and B matrices from GMEM to SMEM using Multicasted TMA.2SM load operations. +// 2. Perform matrix multiply-accumulate (MMA) operations using 2SM tcgen05.mma instruction. +// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +// 4. Read C matrix from global memory (GMEM) to shared memory (SMEM) with TMA. +// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix. +// 6. Store D matrix from shared memory (SMEM) to global memory (GMEM) with TMA. +// +// SM100 2SM tcgen05.mma instructions operate as follows: +// - Mma is launched by only one SM +// With 2SM MMA instructions, only 1 of the 2 CTAs collaborating on MMA executes the instruction. 
+// We call the collaborating CTAs, peer CTAs. And the CTA executing the MMA instruction is called leader CTA. +// - Read matrix A from SMEM or TMEM +// - Read matrix B from SMEM +// - Write accumulator to TMEM +// The accumulator in TMEM must then be loaded to registers before writing back to GMEM. +// +// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types +// and the MMA's M and N dimensions. +// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors. +// These are the A and B fragments of the tcgen05.mma in CuTe terminology. +// CuTe provides these descriptors transparently in the instruction and fragments, shown in this tutorial. +// +// The MMA details: +// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 256x256x16 MMA +// operation. F32 accumulator type is chosen since both C and D matrices use F32. +// This example uses F16xF16 = F32 MMA where: +// TypeA = cutlass::half_t; // MMA A Data Type +// TypeB = cutlass::half_t; // MMA B Data Type +// TypeC = float; // MMA C Data Type +// TypeD = float; // MMA D Data Type +// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// The shared memory buffers for A, B, C, and D matrices. +template // EpiTile_M, EpiTile_N +struct SharedStorage +{ + alignas(128) union { + alignas(128) struct { + alignas(128) cute::ArrayEngine> A; + alignas(128) cute::ArrayEngine> B; + } mainloop; + alignas(128) cute::ArrayEngine> C; + alignas(128) cute::ArrayEngine> D; + } tensors; + + alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM + alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM + + CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(tensors.mainloop.A.begin()), ASmemLayout{}); } + CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(tensors.mainloop.B.begin()), BSmemLayout{}); } + CUTE_DEVICE constexpr auto tensor_sC() { return make_tensor(make_smem_ptr(tensors.C.begin()), CSmemLayout{}); } + CUTE_DEVICE constexpr auto tensor_sD() { return make_tensor(make_smem_ptr(tensors.D.begin()), DSmemLayout{}); } +}; + +// The device kernel +template +__global__ static +void +gemm_device(ATensor mA, // (Gemm_M, Gemm_K) + BTensor mB, // (Gemm_N, Gemm_K) + CTensor mC, // (Gemm_M, Gemm_N) + DTensor mD, // (Gemm_M, Gemm_N) + MmaTiler_MNK mma_tiler, // + EpiTiler_MN epi_tiler_mn, // + TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K> + ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK) + CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A, + CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B, + CUTE_GRID_CONSTANT TmaAtomC const tma_atom_C, + CUTE_GRID_CONSTANT TmaAtomD const tma_atom_D, + Alpha alpha, Beta beta) +{ + // Step 1: The Prologue. + + // The CTA layout within the Cluster: (V,M,N,K) -> CTA idx + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename TiledMMA::AtomThrID{})); + + // Construct the MMA grid coordinate from the CTA grid coordinate + auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate + blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate + blockIdx.y, // MMA-N coordinate + _); // MMA-K coordinate + + // Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed + // by this mma tile. 
+ // CuTe provides local_tile partitioning function. local_tile accepts 4 parameters: + // * Tensor to partition + // * Tiler to use for partitioning + // * Coordinate to use for slicing the partitioned tensor + // * Projection to ignore unwanted modes of the Tiler and Coordinate + auto mma_coord = select<1,2,3>(mma_coord_vmnk); + Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K) + Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K) + Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N) + + if (thread0()) { + print("mA:\t"); print(mA); print("\n"); // mA: ArithTuple(_0,_0) o (512,256):(_1@1,_1@0) + print("mB:\t"); print(mB); print("\n"); // mB: ArithTuple(_0,_0) o (1024,256):(_1@1,_1@0) + print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1) + print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1) + + print("gA:\t"); print(gA); print("\n"); // gA: ArithTuple(_0,0) o (_128,_64,4):(_1@1,_1@0,_64@0) + print("gB:\t"); print(gB); print("\n"); // gB: ArithTuple(_0,0) o (_256,_64,4):(_1@1,_1@0,_64@0) + print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1) + print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1) + } __syncthreads(); + + // The SMEM tensors + + // Allocate SMEM + extern __shared__ char shared_memory[]; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Represent the SMEM buffers for A and B + Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // + // Mma partitioning for A and B + // + + auto mma_v = get<0>(mma_coord_vmnk); + ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate + Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K) + Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N) + Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) + } __syncthreads(); + + // MMA Fragment Allocation + // We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations. 
+ // For tcgen05.mma operations: + // - Matrices A and B are sourced from SMEM + // - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively + // - The first mode of each descriptor represents the SMEM for a single MMA operation + Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) + Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + + // TMEM Allocation + // On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM). + // ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator. + Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N) + + if (thread0()) { + print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) + print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) + } __syncthreads(); + + // TMA Setup + // + // These are TMA partitionings, which have a dedicated custom partitioner. + // In this example, the TMA multicasts the loads across multiple CTAs. + // Loads of A are multicasted along the N dimension of the cluster_shape_VMNK and + // Loads of B are multicasted along the M dimension of the cluster_shape_VMNK. + // Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host. + // For A tensor: The group_modes<0,3> transforms the (MmaA, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaA, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile MK. + // For B tensor: The group_modes<0,3> transforms the (MmaB, NumMma_M, NumMma_K, Tiles_K)-shaped tensor + // into ((MmaB, NumMma_M, NumMma_K), Tiles_K). The partitioning only pays attention to mode-0, the MMA Tile NK. + // Simply put, the TMA will be responsible for everything in mode-0 with a single call to cute::copy. + // The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info. 
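  // Illustrative aside (not part of the original example): with the shapes printed above, the mode grouping
  // used below works out as follows. group_modes<0,3> collapses the (MmaA, NumMma_M, NumMma_K) modes into a
  // single mode, leaving only the k-tile mode for the mainloop to iterate over:
  //   tCgA : ((_128,_16),_1,_4,4)  --group_modes<0,3>-->  (((_128,_16),_1,_4),4)
  //   tCsA : ((_128,_16),_1,_4)    --group_modes<0,3>-->  (((_128,_16),_1,_4))
  // The prints below can be enabled to inspect the grouped tensors directly.
#if 0
  if (thread0()) {
    print("grouped tCgA:\t"); print(group_modes<0,3>(tCgA)); print("\n");
    print("grouped tCsA:\t"); print(group_modes<0,3>(tCsA)); print("\n");
  }
#endif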
+ + // Each CTA with the same m-coord will load a portion of A + // Each CTA with the same n-coord will load a portion of B + // Computation of the multicast masks must take into account the Peer CTA for TMA.2SM + + // Construct the CTA-in-Cluster coordinate for multicasting + auto cta_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(int(cute::block_rank_in_cluster())); + + // Project the cluster_layout for tma_A along the N-modes + auto [tAgA, tAsA] = tma_partition(tma_atom_A, + get<2>(cta_in_cluster_coord_vmnk), // The CTA coordinate along N mode of the cluster + make_layout(size<2>(cluster_layout_vmnk)), // The CTA layout along N mode of the cluster + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + // Project the cluster_layout for tma_B along the M-modes + auto [tBgB, tBsB] = tma_partition(tma_atom_B, + get<1>(cta_in_cluster_coord_vmnk), // The CTA coordinate along M mode of the cluster + make_layout(size<1>(cluster_layout_vmnk)), // The CTA layout along M mode of the cluster + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + // Project the cluster_layout and cta_coord along the N-mode to determine the multicast mask for A + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + // Project the cluster_layout and cta_coord along the M-mode to determine the multicast mask for B + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + // Project the cluster_layout and cta_coord along the VM + VN-modes to determine the multicast mask for C + uint16_t mma_mcast_mask_c = create_tma_multicast_mask<0,1>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk) | + create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); + + // Calculate total bytes that TMA will transfer each tile to track completion, accounting for TMA.2SM + int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA)) + + size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB)); + + if (thread0()) { + print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0)) + print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0) + print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0)) + printf("tma_transaction_bytes: %d\n", tma_transaction_bytes); + printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a); + printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b); + printf("mma_mcast_mask_c: %x\n", mma_mcast_mask_c); + } __syncthreads(); + + // Barrier Initialization + auto elect_one_thr = cute::elect_one_sync(); + auto elect_one_warp = (threadIdx.x / 32 == 0); + auto elect_one_cta = get<0>(cta_in_cluster_coord_vmnk) == Int<0>{}; + + // Barriers in SMEM should be initialized by a single thread. + if (elect_one_warp && elect_one_thr) { + // The number of CTAs that participates in multicast operation with this CTA (for both A and B matrices) + int num_mcast_participants = size<1>(cluster_layout_vmnk) + size<2>(cluster_layout_vmnk) - 1; + cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ num_mcast_participants); + cute::initialize_barrier(shared_storage.tma_barrier, /* num_threads */ 1); + } + int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. 
+ int tma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit. + cute::cluster_sync(); // Make sure all CTAs in Cluster observe barrier init and TMEM alloc. + + // Step 2: The Mainloop. + + // Set mma accumlate option to zero so that the first MMA instruction will clear the TMEM accumulator. + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + // Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM + for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) + { + // Step 2a: Load A and B tiles + + // TMA Load Operations: + // - Execute asynchronous TMA loads with single thread + // - Both peer and leader CTAs initiate TMA loads + // - Set expected transaction bytes. For 2SM TMA instructions, the transaction bytes counts both CTAs. + // - Although TMAs are initiated by both peer and leader CTAs, the barrier is only set and waited by the leader CTA. + // - Initiate asynchronous transfers with a multicast mask that includes all CTAs that participate in multicast. + if (elect_one_warp && elect_one_thr) { // TMA loads are executed by one thread + if (elect_one_cta) { // Only the leader CTA waits for TMA transactions + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); // Set the expected transaction bytes for the TMA loads + } + copy(tma_atom_A.with(shared_storage.tma_barrier,tma_mcast_mask_a), tAgA(_,k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile + copy(tma_atom_B.with(shared_storage.tma_barrier,tma_mcast_mask_b), tBgB(_,k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile + } + + // Step 2b: Execute the MMAs for this tile + + if (elect_one_cta) { + // Wait for TMA loads to complete on leader CTAs + cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit); + tma_barrier_phase_bit ^= 1; + + // tcgen05.mma instructions require single-thread execution: + // - Only one warp performs the MMA-related loop operations + // - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp + // - No explicit elect_one_sync region is needed from the user + if (elect_one_warp) { + // Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + // Ensure MMAs are completed, only then we can reuse the A and B SMEM. + cutlass::arch::umma_arrive_multicast_2x1SM(&shared_storage.mma_barrier, mma_mcast_mask_c); // All multicasting CTAs encoded in mask. + } + } + // Wait MMAs to complete to avoid overwriting the A and B SMEM. + cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit); + mma_barrier_phase_bit ^= 1; + } + + // Step 3: The Epilogue. 
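  // Illustrative aside (the numbers follow from the host-side setup and the prints above; not part of the
  // original example): this CTA's accumulator tCtAcc covers a (_128,_256) tile of D, while the epilogue tiler
  // passed in from the host is (_128,_64). The zipped_divide below therefore produces 256/64 = 4 epilogue
  // subtiles per MMA tile, and the loop that follows stages C and D through SMEM one 128x64 subtile at a time.
#if 0
  if (thread0()) {
    print("epi_tiler_mn:\t"); print(epi_tiler_mn); print("\n");
    print("tCtAcc shape:\t"); print(shape(tCtAcc)); print("\n");
  }
#endif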
+ + // Apply rank-2 epilogue tiler to rank-2 MMA_V mode + auto epi_tiler_v = make_tile(epi_tiler_mn); // (EpiTile) + Tensor tAcc_epi = zipped_divide(tCtAcc, epi_tiler_v); // (EpiTile,NumTiles) + Tensor gC_epi = zipped_divide(tCgC, epi_tiler_v); // (EpiTile,NumTiles) + Tensor gD_epi = zipped_divide(tCgD, epi_tiler_v); // (EpiTile,NumTiles) + + // Construct corresponding SMEM tensors + Tensor sC_epi = shared_storage.tensor_sC(); // (EpiTile) + Tensor sD_epi = shared_storage.tensor_sD(); // (EpiTile) + + // Partition for TMA + auto [tGS_gC, tGS_sC] = tma_partition(tma_atom_C, sC_epi, gC_epi); // (GMEM -> SMEM) + auto [tSG_gD, tSG_sD] = tma_partition(tma_atom_D, sD_epi, gD_epi); // (SMEM -> GMEM) + + // Reset transaction bytes for C load + tma_transaction_bytes = sizeof(make_tensor_like(tGS_sC)); + + // Partition for TMEM accumulators load (TMEM -> RMEM) + TiledCopy t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tAcc_epi(_,_0{})); + ThrCopy thr_t2r = t2r_copy.get_slice(threadIdx.x); + Tensor tTR_tAcc = thr_t2r.partition_S(tAcc_epi); // (TmemCpy,NumTmemCpy,NumTiles) + Tensor tTR_sC = thr_t2r.partition_D(sC_epi); // (TmemCpy,NumTmemCpy) + Tensor tTR_sD = thr_t2r.partition_D(sD_epi); // (TmemCpy,NumTmemCpy) + // Allocate register tensors + Tensor tTR_rC = make_tensor_like(tTR_sC); // (TmemCpy,NumTmemCpy) + Tensor tTR_rD = make_fragment_like(tTR_sD); // (TmemCpy,NumTmemCpy) + + // Loop over the epilogue tiles + CUTE_UNROLL + for (int epi_tile_idx = 0; epi_tile_idx < size<2>(tTR_tAcc); ++epi_tile_idx) { + // TMA Load C: GMEM -> SMEM + if (elect_one_warp && elect_one_thr) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); + copy(tma_atom_C.with(shared_storage.tma_barrier, 0 /*no multicast*/), tGS_gC(_,epi_tile_idx), tGS_sC); + } + // All threads wait for C TMA load to complete + cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit); + tma_barrier_phase_bit ^= 1; + + // Load C: SMEM -> RMEM + copy_aligned(tTR_sC, tTR_rC); + + // Load Acc: TMEM -> RMEM + copy(t2r_copy, tTR_tAcc(_,_,epi_tile_idx), tTR_rD); + + // Compute D = beta * C + alpha * (A*B) + axpby(beta, tTR_rC, alpha, tTR_rD); + + // Store D: RMEM -> SMEM + __syncthreads(); // Ensure C loads are finished before reusing smem (unnecessary if smem layouts match) + copy_aligned(tTR_rD, tTR_sD); + + // TMA Store D: SMEM -> GMEM + tma_store_fence(); // Ensure D smem stores are visible to TMA + __syncthreads(); // Ensure all threads have issued fence + if (elect_one_warp && elect_one_thr) { + copy(tma_atom_D, tSG_sD, tSG_gD(_,epi_tile_idx)); + tma_store_arrive(); // issuing thread commits D TMA store + tma_store_wait<0>(); // issuing thread waits for D TMA store to complete + } + __syncthreads(); // All threads sync with issuing thread + } +} + +template +void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, + TypeB const* device_ptr_B, LayoutB layout_B, + TypeC const* device_ptr_C, LayoutC layout_C, + TypeD * device_ptr_D, LayoutD layout_D, + Alpha const alpha, Beta const beta) +{ + assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M + assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M + assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N + assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N + assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K + + // Represent the full tensors in global memory + Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K) + Tensor mB = 
make_tensor(make_gmem_ptr(device_ptr_B), layout_B);   // (Gemm_N, Gemm_K)
+  Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C);   // (Gemm_M, Gemm_N)
+  Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D);   // (Gemm_M, Gemm_N)
+
+  // Get M, N, K dimensions of the GEMM we are running
+  auto Gemm_M = shape<0>(layout_A);
+  auto Gemm_N = shape<0>(layout_B);
+  auto Gemm_K = shape<1>(layout_A);
+  std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
+
+  ////////////////////////////////////////////////////////////
+  //
+  // Initialize the GEMM kernel parameters
+  //
+  ////////////////////////////////////////////////////////////
+
+  // Create TiledMma. make_tiled_mma takes the target instruction and an (optional) instruction layout as parameters to create a
+  // larger TiledMma from the given mma instruction.
+  // See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
+  TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_2x1SM_SS{});   // A and B layouts
+
+  // We can also print and inspect the tiled_mma
+  print(tiled_mma);
+  // TiledMMA
+  //   ThrLayoutVMNK:  (_2,_1,_1,_1):(_1,_0,_0,_0)
+  //   PermutationMNK: (_,_,_)
+  //   MMA_Atom
+  //     ThrID:      _2:_1
+  //     Shape_MNK:  (_256,_256,_16)                      // MmaM, MmaN, MmaK (MmaK is constant for each instr.)
+  //     LayoutA_TV: (_2,(_128,_16)):(_128,(_1,_256))     // TV -> MmaCoordinate mapping for A matrix
+  //     LayoutB_TV: (_2,(_128,_16)):(_128,(_1,_256))     // TV -> MmaCoordinate mapping for B matrix
+  //     LayoutC_TV: (_2,(_128,_256)):(_128,(_1,_256))    // TV -> MmaCoordinate mapping for C matrix
+
+  // Define MMA tiler sizes (static)
+  auto bM = tile_size<0>(tiled_mma);               // MMA Tile M. We'll use 1 MMA per MMA Tile M.
+  auto bN = tile_size<1>(tiled_mma);               // MMA Tile N. We'll use 1 MMA per MMA Tile N.
+  auto bK = tile_size<2>(tiled_mma) * Int<4>{};    // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16.
+  auto mma_tiler = make_shape(bM, bN, bK);         // (MMA_M, MMA_N, MMA_K)
+
+  // In SM90, the MMAs are CTA-local and perform thread-level partitioning.
+  // In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
+  // Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
+  // and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
+  // The MMA's partitioning then yields the CTA-local work.
+
+  if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
+    std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
+    return;
+  }
+
+  if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
+    std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
+    return;
+  }
+
+  //
+  // Determine the SMEM layouts:
+  //
+
+  // * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
+  // * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
+  //   These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
+  //   where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of times
+  //   the MMA instruction is repeated in the M/N mode and K mode of the MMA tile, respectively.
+  // * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
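  // A rough SMEM budget for this configuration (an aside, not part of the original example): the
  // post-partitioned shapes computed just below come out to ((_128,_16),_1,_4) for A and ((_256,_16),_1,_4)
  // for B, i.e. 128*16*4 = 8192 half_t (16 KB) and 256*16*4 = 16384 half_t (32 KB) staged per CTA per k-tile.
  // The epilogue C and D buffers defined further down are 128x64 floats (32 KB) each. Because SharedStorage
  // overlaps the epilogue C/D buffers with the mainloop A/B buffers in a union, the dynamic SMEM request made
  // at launch is driven by the largest member rather than by the sum of all buffers.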
+ + // Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K) + auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler))); + // Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K) + auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler))); + + // Print and inspect mma_shape_A, and mma_shape_B for this example. + print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4) + print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4) + + // A and B tensors are swizzled in SMEM to improve MMA performance. + // * However, expressing swizzled layouts is very hard. + // * CuTe provides tile_to_mma_shape functions for SM100 to create swizzled layouts for post-partitioned Mma Shapes + auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_A); + auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_B); + + // Print and inspect sA_layout and sB_layout for this example. + print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16) + print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + + // + // Epilogue parameters + // + + // Pre-partitioned Tile Shape (MmaTile_M, MmaTile_N) to post-partitioned ((MmaM,MmaN), NumMma_M, NumMma_N) + auto mma_shape_C = partition_shape_C(tiled_mma, make_shape(size<0>(mma_tiler), size<1>(mma_tiler))); + + // For TMA epilogue performance it may be beneficial to iterate over the output in smaller tiles than the MMA tile + auto epi_tiler = make_tile(size<0,0>(mma_shape_C), size<0,1>(mma_shape_C) / Int<4>{}); // 4 TMA copies per CTA per MMA tile + + // SMEM layouts for C and D should match the epilogue tile + auto sC_layout_mn = tile_to_shape(UMMA::Layout_K_SW128_Atom{}, // MMA K-major is equivalent to epilogue N-major + make_shape(size<0>(epi_tiler), size<1>(epi_tiler))); + auto sC_layout = group<0,2>(sC_layout_mn); // Group modes for tma_partition + + auto sD_layout_mn = tile_to_shape(UMMA::Layout_K_SW128_Atom{}, // MMA K-major is equivalent to epilogue N-major + make_shape(size<0>(epi_tiler), size<1>(epi_tiler))); + auto sD_layout = group<0,2>(sD_layout_mn); // Group modes for tma_partition + + print("sC_layout:\t"); print(sC_layout); print("\n"); // sC_layout: Sw<3,4,3> o smem_ptr[32b](unset) o ((_8,_16),(_32,_2)):((_32,_256),(_1,_4096)) + print("sD_layout:\t"); print(sD_layout); print("\n"); // sD_layout: Sw<3,4,3> o smem_ptr[32b](unset) o ((_8,_16),(_32,_2)):((_32,_256),(_1,_4096)) + + // Now we can find the SMEM allocation size + using SMEMStorage = SharedStorage; + + // + // TMA Descriptor Creation (Host Side) + // + + // The cluster shape and layout + auto cluster_shape = make_shape(Int<4>{}, Int<4>{}, Int<1>{}); + Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), + make_tile(typename decltype(tiled_mma)::AtomThrID{})); + + // SM100 interface for creating TMA loads. + Copy_Atom tma_atom_A = make_tma_atom_A_sm100( + SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction. + mA, // Source GMEM tensor + sA_layout, // Destination SMEM layout + mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes. 
+ tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning. + cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed. + // We have make_tma_atom_[A|B]_sm100 and which determines the multicast mode. + Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); // (Gemm_M, Gemm_K) + + print("tma_atom_A:\t"); print(tma_atom_A); print("\n"); + // tma_atom_A: Copy_Atom + // ThrID: _2:_1 + // ValLayoutSrc: (_2,_8192):(_8192,_1) + // ValLayoutDst: (_2,_8192):(_8192,_1) + // ValLayoutRef: (_2,_8192):(_8192,_1) + // ValueType: 16b + + // SM100 interface for creating TMA loads. + Copy_Atom tma_atom_B = make_tma_atom_B_sm100( + SM100_TMA_2SM_LOAD_MULTICAST{}, // TMA load operation -- Multicasting 2SM instruction. + mB, // Source GMEM tensor + sB_layout, // Destination SMEM layout + mma_tiler, // MmaTiler_MNK. Unlike Sm90 interface where the tiler only included M and K modes. + tiled_mma, // Sm100 also requires the TiledMma to perform CTA-level partitioning. + cluster_layout_vmnk); // ClusterLayout_VMNK. Unlike Sm90 interface where only the multicasting mode is passed. + // We have make_tma_atom_[A|B]_sm100 and which determines the multicast mode. + Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // (Gemm_N, Gemm_K) + + print("tma_atom_B:\t"); print(tma_atom_B); print("\n"); + // tma_atom_B: Copy_Atom + // ThrID: _2:_1 + // ValLayoutSrc: (_2,_8192):(_8192,_1) + // ValLayoutDst: (_2,_8192):(_8192,_1) + // ValLayoutRef: (_2,_8192):(_8192,_1) + // ValueType: 16b + + Copy_Atom tma_atom_C = make_tma_atom( + SM90_TMA_LOAD{}, // TMA load operation + mC, // Source GMEM tensor + sC_layout, // Destination SMEM layout + epi_tiler); // MN Tiler for epilogue + Tensor mC_tma = tma_atom_C.get_tma_tensor(shape(mC)); // (Gemm_M, Gemm_N) + + print("tma_atom_C:\t"); print(tma_atom_C); print("\n"); + // tma_atom_C: Copy_Atom + // ThrID: _1:_0 + // ValLayoutSrc: (_1,_4096):(_0,_1) + // ValLayoutDst: (_1,_4096):(_0,_1) + // ValLayoutRef: (_1,_4096):(_0,_1) + // ValueType: 32b + + Copy_Atom tma_atom_D = make_tma_atom( + SM90_TMA_STORE{}, // TMA store operation + mD, // Destination GMEM tensor + sD_layout, // Source SMEM layout + epi_tiler); // MN Tiler for epilogue + Tensor mD_tma = tma_atom_D.get_tma_tensor(shape(mD)); // (Gemm_M, Gemm_N) + + print("tma_atom_D:\t"); print(tma_atom_D); print("\n"); + // tma_atom_D: Copy_Atom + // ThrID: _1:_0 + // ValLayoutSrc: (_1,_4096):(_0,_1) + // ValLayoutDst: (_1,_4096):(_0,_1) + // ValLayoutRef: (_1,_4096):(_0,_1) + // ValueType: 32b + + //////////////////////////////////////////////////////////// + // + // Launch GEMM kernel + // + //////////////////////////////////////////////////////////// + + dim3 dimBlock(128); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), + round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + int smemBytes = sizeof(SMEMStorage); + + auto* kernel_ptr = &gemm_device; + + // Set kernel attributes (set SMEM) + CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smemBytes)); + + printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z); + printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z); + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr, + mA_tma, mB_tma, 
mC_tma, mD_tma, + mma_tiler, epi_tiler, tiled_mma, cluster_shape, + tma_atom_A, tma_atom_B, tma_atom_C, tma_atom_D, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if ((props.major != 10) || (props.major == 10 && props.minor > 1)) { + std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl; + std::cerr << " Found " << props.major << "." << props.minor << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + int Gemm_M = 512; + if (argc >= 2) + sscanf(argv[1], "%d", &Gemm_M); + + int Gemm_N = 1024; + if (argc >= 3) + sscanf(argv[2], "%d", &Gemm_N); + + int Gemm_K = 256; + if (argc >= 4) + sscanf(argv[3], "%d", &Gemm_K); + + //////////////////////////////////////////////////////////// + // + // Create A, B, C, and D tensors + // + //////////////////////////////////////////////////////////// + // Define the data types. A and B types are same for MMA instruction. + using TypeA = cutlass::half_t; // MMA A Data Type + auto type_str_a = "half_t"; + using TypeB = cutlass::half_t; // MMA B Data Type + auto type_str_b = "half_t"; + using TypeC = float; // MMA C Data Type + [[maybe_unused]] auto type_str_c = "float"; + using TypeD = float; // MMA D Data Type + auto type_str_d = "float"; + using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type. + + // A tensor MxK K-major (Layout T = Row-Major) + Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1) + // B tensor NxK K-major (Layout N = Column-Major) + Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K), + make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1) + // C tensor MxN N-major (Layout T = Row-Major) + Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + // D tensor MxN N-major (Layout T = Row-Major) + Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N), + make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1) + + // Host allocations and host CuTe tensors for A, B, and C tensors. + thrust::host_vector host_A(Gemm_M * Gemm_K); + Tensor host_tensor_A = make_tensor(host_A.data(), layout_A); + print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1) + + thrust::host_vector host_B(Gemm_N * Gemm_K); + Tensor host_tensor_B = make_tensor(host_B.data(), layout_B); + print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1) + + thrust::host_vector host_C(Gemm_M * Gemm_N); + Tensor host_tensor_C = make_tensor(host_C.data(), layout_C); + print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1) + + // Note that we don't need a host_tensor for D yet. 
+ thrust::device_vector device_D(Gemm_M * Gemm_N); + + // Initialize A, B, and C tensors with random values. + initialize_tensor(host_tensor_A); + initialize_tensor(host_tensor_B); + initialize_tensor(host_tensor_C); + + // Copy A, B, and C tensors from host memory to device memory + thrust::device_vector device_A = host_A; + thrust::device_vector device_B = host_B; + thrust::device_vector device_C = host_C; + + using Alpha = float; + using Beta = float; + Alpha alpha = 1.0f; + Beta beta = 0.0f; + // Setup input and output tensors, and the kernel parameters; and execute the kernel on device + gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A, + device_B.data().get(), layout_B, + device_C.data().get(), layout_C, + device_D.data().get(), layout_D, + alpha, beta); + // Host allocation for D tensor and transfer D tensor from device to host + thrust::host_vector host_D = device_D; + // Create a non-owning CuTe tensor for D tensor + Tensor host_tensor_D = make_tensor(host_D.data(), layout_D); + + //////////////////////////////////////////////////////////// + // + // Execute reference GEMM kernel + // + //////////////////////////////////////////////////////////// + + thrust::host_vector host_reference_D(Gemm_M*Gemm_N); + auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D); + reference_gemm(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta); + + //////////////////////////////////////////////////////////// + // + // Compare results + // + //////////////////////////////////////////////////////////// + + auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A, + type_str_b, host_tensor_B, + type_str_d, host_tensor_D, host_reference_tensor_D); + bool success = relative_error <= 0.0; + std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl; +#else + std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; +#endif + + return 0; +} diff --git a/examples/cute/tutorial/blackwell/CMakeLists.txt b/examples/cute/tutorial/blackwell/CMakeLists.txt new file mode 100644 index 00000000..35db1ec4 --- /dev/null +++ b/examples/cute/tutorial/blackwell/CMakeLists.txt @@ -0,0 +1,54 @@ +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +if (CUTLASS_NVCC_ARCHS MATCHES 100a) +cutlass_example_add_executable( + cute_tutorial_01_mma_sm100 + 01_mma_sm100.cu +) + +cutlass_example_add_executable( + cute_tutorial_02_mma_tma_sm100 + 02_mma_tma_sm100.cu +) + +cutlass_example_add_executable( + cute_tutorial_03_mma_tma_multicast_sm100 + 03_mma_tma_multicast_sm100.cu +) + +cutlass_example_add_executable( + cute_tutorial_04_mma_tma_2sm_sm100 + 04_mma_tma_2sm_sm100.cu +) + +cutlass_example_add_executable( + cute_tutorial_05_mma_tma_epi_sm100 + 05_mma_tma_epi_sm100.cu +) +endif() diff --git a/examples/cute/tutorial/blackwell/example_utils.hpp b/examples/cute/tutorial/blackwell/example_utils.hpp new file mode 100644 index 00000000..f6332002 --- /dev/null +++ b/examples/cute/tutorial/blackwell/example_utils.hpp @@ -0,0 +1,105 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once +#include // CuTe tensor implementation +#include + +template +void +reference_gemm(TensorA const& tensor_A, TensorB const& tensor_B, + TensorC const& tensor_C, TensorD & tensor_D, + Alpha alpha, Beta beta) +{ + using namespace cute; + for (int m = 0; m < size<0>(tensor_D); ++m) { + for (int n = 0; n < size<1>(tensor_D); ++n) { + AccType c = AccType(0.f); + for (int k = 0; k < size<1>(tensor_A); ++k) { + c += tensor_A(m,k) * tensor_B(n,k); + } + tensor_D(m,n) = alpha * c + beta * tensor_C(m,n); + } + } +} + +template +bool +compare_results(TensorA const& tensor_A, TensorB const& tensor_B, + TensorC const& tensor_C, TensorD const& tensor_D, + RefTensorD const& ref_tensor_D, + bool print_diff = false) +{ + using namespace cute; + auto norm_A = matrix_inf_norm(tensor_A); + auto norm_B = matrix_inf_norm(tensor_B); + auto norm_C = matrix_inf_norm(tensor_C); + auto norm_D = matrix_inf_norm(tensor_D); + auto norm_ref_D = matrix_inf_norm(ref_tensor_D); + auto norm_diff = matrix_diff_inf_norm(tensor_D, ref_tensor_D); + + if (print_diff) { + for (int m = 0; m < size<0>(tensor_D); ++m) { + for (int n = 0; n < size<1>(tensor_D); ++n) { + std::cout << m << "," << n << " : " << tensor_D(m,n) << " vs. " << ref_tensor_D(m,n) << std::endl; + } + } + } + + std::cout << "norm (A) : " << norm_A.inf_norm << std::endl; + std::cout << "norm (B) : " << norm_B.inf_norm << std::endl; + std::cout << "norm (C) : " << norm_C.inf_norm << std::endl; + std::cout << "norm (D) : " << norm_D.inf_norm << std::endl; + std::cout << "norm (ref_D) : " << norm_ref_D.inf_norm << std::endl; + std::cout << "norm (D-ref_D) : " << norm_diff.inf_norm << std::endl; + + return (!norm_A.found_nan) && (!norm_B.found_nan) && + (!norm_C.found_nan) && (!norm_D.found_nan) && (!norm_ref_D.found_nan) && // There are no NaNs + (norm_A.inf_norm > 0.0) && (norm_B.inf_norm > 0.0) && + (norm_C.inf_norm > 0.0) && (norm_D.inf_norm > 0.0) && (norm_ref_D.inf_norm > 0.0) && // Values in tensors aren't zeros + (norm_diff.inf_norm <= 0.0); // Diff (ref_D-D) == 0 +} + +template +void +initialize_tensor(Tensor& tensor, cute::tuple value_range = {-2, 2}) +{ + using DataType = typename Tensor::element_type; + auto [min, max] = value_range; + for (int i = 0; i < cute::size(tensor); i++) { + tensor(i) = DataType(int((max-min)*(rand() / double(RAND_MAX)) + min)); + } +} diff --git a/examples/cute/tutorial/hopper/CMakeLists.txt b/examples/cute/tutorial/hopper/CMakeLists.txt new file mode 100644 index 00000000..7498090e --- /dev/null +++ b/examples/cute/tutorial/hopper/CMakeLists.txt @@ -0,0 +1,38 @@ + +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + cute_tutorial_wgmma_sm90 + wgmma_sm90.cu +) + +cutlass_example_add_executable( + cute_tutorial_wgmma_tma_sm90 + wgmma_tma_sm90.cu +) diff --git a/examples/cute/tutorial/hopper/wgmma_sm90.cu b/examples/cute/tutorial/hopper/wgmma_sm90.cu new file mode 100644 index 00000000..405bb310 --- /dev/null +++ b/examples/cute/tutorial/hopper/wgmma_sm90.cu @@ -0,0 +1,611 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/cluster_launch.hpp" + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +using namespace cute; + +template // (N,K,P) +struct SharedStorage +{ + alignas(128) cute::ArrayEngine> A; + alignas(128) cute::ArrayEngine> B; +}; + +template +__global__ static +__launch_bounds__(decltype(size(TiledMma{}))::value) +void +gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, + TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, + TC * C, CStride dC, TiledMma mma, + Alpha alpha, Beta beta) +{ + // Preconditions + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K) + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K) + + CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads + CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads + + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K + + CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK + CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK + CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN + + // + // Full and Tiled Tensors + // + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K) + Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K) + Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N) + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Shared memory tensors + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), ASmemLayout{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), BSmemLayout{}); // (BLK_N,BLK_K,PIPE) + + // + // Partition the copying of A and B tiles across the threads + // + + ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x); + Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k) + Tensor sA_ = as_position_independent_swizzle_tensor(sA); + Tensor tAsA = thr_copy_a.partition_D(sA_); // (CPY,CPY_M,CPY_K,PIPE) + + ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x); + Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k) + Tensor sB_ = as_position_independent_swizzle_tensor(sB); + Tensor tBsB = thr_copy_b.partition_D(sB_); // (CPY,CPY_N,CPY_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tAgA) == 
size<2>(tAsA)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K + + // + // PREFETCH + // + + // auto K_PIPE_MAX = size<3>(tAsA); + + // // Total count of tiles + // int k_tile_count = size<3>(tAgA); + // // Current tile index in gmem to read from + // int k_tile_next = 0; + + // // Start async loads for all pipes but the last + // CUTE_UNROLL + // for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) { + // copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe)); + // copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe)); + // cp_async_fence(); + // --k_tile_count; + // if (k_tile_count > 0) { ++k_tile_next; } + // } + + // + // Define A/B partitioning and C accumulators + // + + ThrMMA thr_mma = mma.get_slice(threadIdx.x); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + + // Allocate registers for pipelining + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + // Allocate the accumulators -- same size as the projected data + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M + CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N + CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K + + // Clear the accumulators + clear(tCrC); + +#if 0 + if(thread0()) { + print(" mA : "); print( mA); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mB : "); print( mB); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sB : "); print( sB); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tBsB : "); print(tBsB); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mC : "); print( mC); print("\n"); + print(" gC : "); print( gC); print("\n"); + print("tCsA : "); print(tCsA); print("\n"); + print("tCsB : "); print(tCsB); print("\n"); + print("tCgC : "); print(tCgC); print("\n"); + print("tCrA : "); print(tCrA); print("\n"); + print("tCrB : "); print(tCrB); print("\n"); + print("tCrC : "); print(tCrC); print("\n"); + } +#endif + +#if 1 + + // Total number of k-tiles + auto K_TILE_MAX = size<3>(tAgA); + // Number of pipelined k-tiles in smem + auto K_PIPE_MAX = size<3>(tAsA); + + // + // PREFETCH + // + + // Prefetch all but the last + CUTE_UNROLL + for (int k = 0; k < K_PIPE_MAX-1; ++k) + { + copy(copy_a, tAgA(_,_,_,k), tAsA(_,_,_,k)); + copy(copy_b, tBgB(_,_,_,k), tBsB(_,_,_,k)); + cp_async_fence(); + } + + // Clear the accumulators + clear(tCrC); + + __syncthreads(); + + // + // PIPELINED MAIN LOOP + // + + // Current pipe to read from + int k_pipe_read = 0; + // Current pipe to write to + int k_pipe_write = K_PIPE_MAX-1; + + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile) + { + int k_tile_next = k_tile + (K_PIPE_MAX-1); + k_tile_next = (k_tile_next >= K_TILE_MAX) ? 
K_TILE_MAX-1 : k_tile_next; + + // + // Copy gmem to smem for k_tile_write + // + + copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe_write)); + copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe_write)); + cp_async_fence(); + + // Advance k_pipe_write + ++k_pipe_write; + k_pipe_write = (k_pipe_write == K_PIPE_MAX) ? 0 : k_pipe_write; + + // + // Compute on k_tile + // + + // Wait on all cp.async -- optimize by pipelining to overlap GMEM reads + cp_async_wait<0>(); + + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(mma, tCrA(_,_,_,k_pipe_read), tCrB(_,_,_,k_pipe_read), tCrC); + warpgroup_commit_batch(); + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait<0>(); + warpgroup_fence_operand(tCrC); + + // Advance k_pipe_read + ++k_pipe_read; + k_pipe_read = (k_pipe_read == K_PIPE_MAX) ? 0 : k_pipe_read; + } + +#endif + + // + // Epilogue + // + + axpby(alpha, tCrC, beta, tCgC); +} + +// Setup params for a NT GEMM +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define NT strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); // (dM, dK) + auto dB = make_stride(Int<1>{}, ldB); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(bM,bK,bP)); + auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(bN,bK,bP)); + + // Define the thread layouts (static) + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout>{}, // Thr layout 32x4 m-major + Layout>{});// Val layout 8x1 m-major + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout>{}, // Thr layout 32x4 n-major + Layout>{});// Val layout 8x1 n-major + + TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS{}); + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + // + // Setup and Launch + // + + // Launch parameter setup + dim3 dimBlock(size(tiled_mma)); + dim3 dimCluster(1, 1, 1); + dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x), + round_up(size(ceil_div(n, bN)), dimCluster.y)); + int smemBytes = sizeof(SharedStorage); + + auto* kernel_ptr = &gemm_device; + + CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smemBytes)); + + // Kernel Launch + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr, + prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, tiled_mma, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +// Setup params for a TN GEMM +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t 
stream = 0) +{ + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom{}, make_shape(bM,bK,bP)); + auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom{}, make_shape(bN,bK,bP)); + + // Define the thread layouts (static) + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 16x8 k-major + Layout>{}); // Val layout 1x8 + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 16x8 k-major + Layout>{}); // Val layout 1x8 + + TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS{}); + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + // + // Setup and Launch + // + + // Launch parameter setup + int smem_size = int(sizeof(SharedStorage)); + dim3 dimBlock(size(tiled_mma)); + dim3 dimCluster(1, 1, 1); + dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x), + round_up(size(ceil_div(n, bN)), dimCluster.y)); + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + + void const* kernel_ptr = reinterpret_cast( + &gemm_device); + + CUTE_CHECK_ERROR(cudaFuncSetAttribute( + kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + // Kernel Launch + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr, + prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, tiled_mma, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +template +void +gemm(char transA, char transB, int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + if (transA == 'N' && transB == 'T') { + return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } else + if (transA == 'T' && transB == 'N') { + return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } + assert(false && "Not implemented"); +} + + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 8) { + std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl; + // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits. 
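+    // Note that, beyond this runtime compute-capability check, the GEMM path below is
+    // gated at compile time by CUTLASS_ARCH_MMA_SM90A_SUPPORTED. The example therefore
+    // typically needs to be built for an sm_90a target as well, for instance with
+    // `-DCUTLASS_NVCC_ARCHS=90a` through CMake or `nvcc -gencode arch=compute_90a,code=sm_90a`
+    // directly (illustrative build flags, not part of this file); otherwise the example
+    // prints the waiver message further below and exits.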
+ return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED) + + int m = 5120; + if (argc >= 2) + sscanf(argv[1], "%d", &m); + + int n = 5120; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 4096; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + char transA = 'N'; + if (argc >= 5) + sscanf(argv[4], "%c", &transA); + + char transB = 'T'; + if (argc >= 6) + sscanf(argv[5], "%c", &transB); + + using TA = cute::half_t; + using TB = cute::half_t; + using TC = cute::half_t; + using TI = cute::half_t; + + TI alpha = TI(1.0f); + TI beta = TI(0.0f); + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + // Initialize the tensors + for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1)); + for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1)); + for (int j = 0; j < m*n; ++j) h_C[j] = TC(0); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + int ldA = 0, ldB = 0, ldC = m; + + if (transA == 'N') { + ldA = m; + } else if (transA == 'T') { + ldA = k; + } else { + assert(false); + } + + if (transB == 'N') { + ldB = k; + } else if (transB == 'T') { + ldB = n; + } else { + assert(false); + } + + // Run once + d_C = h_C; + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + +#else + std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. 
Test is waived \n" << std::endl; +#endif + + return 0; +} diff --git a/examples/cute/tutorial/wgmma_sm90.cu b/examples/cute/tutorial/hopper/wgmma_tma_sm90.cu similarity index 92% rename from examples/cute/tutorial/wgmma_sm90.cu rename to examples/cute/tutorial/hopper/wgmma_tma_sm90.cu index eb634e23..77a30890 100644 --- a/examples/cute/tutorial/wgmma_sm90.cu +++ b/examples/cute/tutorial/hopper/wgmma_tma_sm90.cu @@ -55,8 +55,8 @@ template // (N,K,P) struct SharedStorage { - array_aligned> smem_A; - array_aligned> smem_B; + alignas(128) cute::ArrayEngine> A; + alignas(128) cute::ArrayEngine> B; uint64_t tma_barrier[size<2>(SmemLayoutA{})]; uint64_t mma_barrier[size<2>(SmemLayoutA{})]; @@ -110,8 +110,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; SharedStorage& smem = *reinterpret_cast(shared_memory); - Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // // Partition the copying of A and B tiles @@ -132,8 +132,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, group_modes<0,2>(sB), group_modes<0,2>(gB)); // (TMA,k) and (TMA,PIPE) // The TMA is responsible for copying everything in mode-0 of tAsA and tBsB - constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) + - CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB); + constexpr int tma_transaction_bytes = sizeof(make_tensor_like(tensor<0>(tAsA))) + + sizeof(make_tensor_like(tensor<0>(tBsB))); // // PREFETCH @@ -171,7 +171,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, if ((warp_idx == 0) && lane_predicate) { // Set expected Tx Bytes after each reset / init - ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes); + ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], tma_transaction_bytes); copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe)); copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe)); } @@ -242,7 +242,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, // Wait for Consumer to complete consumption ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase()); // Set expected Tx Bytes after each reset / init - ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes); + ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], tma_transaction_bytes); copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe)); copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe)); ++write_state; @@ -393,27 +393,25 @@ gemm_tn(int m, int n, int k, // // Launch parameter setup - int smem_size = int(sizeof(SharedStorage)); dim3 dimBlock(size(tiled_mma)); dim3 dimCluster(2, 1, 1); dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x), round_up(size(ceil_div(n, bN)), dimCluster.y)); - cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + int smemBytes = sizeof(SharedStorage); - void const* kernel_ptr = reinterpret_cast( - &gemm_device); + auto* kernel_ptr = &gemm_device; - CUTE_CHECK_ERROR(cudaFuncSetAttribute( - kernel_ptr, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); + CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr, + 
cudaFuncAttributeMaxDynamicSharedMemorySize, + smemBytes)); // Kernel Launch - cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr, + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr, prob_shape, cta_tiler, A, tmaA, B, tmaB, @@ -448,8 +446,10 @@ gemm(char transA, char transB, int m, int n, int k, int main(int argc, char** argv) { - cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; @@ -461,7 +461,7 @@ int main(int argc, char** argv) return 0; } -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM90A_SUPPORTED) int m = 512; if (argc >= 2) @@ -553,10 +553,8 @@ int main(int argc, char** argv) printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); #else - - std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; + std::cout << "CUTLASS_ARCH_MMA_SM90A_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; #endif return 0; - } diff --git a/examples/cute/tutorial/sgemm_sm80.cu b/examples/cute/tutorial/sgemm_sm80.cu index bcc31a0a..50914548 100644 --- a/examples/cute/tutorial/sgemm_sm80.cu +++ b/examples/cute/tutorial/sgemm_sm80.cu @@ -41,17 +41,27 @@ #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/helper_cuda.hpp" +template +struct SharedStorage +{ + cute::ArrayEngine> A; + cute::ArrayEngine> B; +}; + template __global__ static __launch_bounds__(decltype(size(TiledMma{}))::value) void gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, - TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, - TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, + TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, S2RAtomA s2r_atom_a, + TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, S2RAtomB s2r_atom_b, TC * C, CStride dC, CSmemLayout , TiledMma mma, Alpha alpha, Beta beta) { @@ -95,10 +105,11 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) // Shared memory buffers - __shared__ TA smemA[cosize_v]; - __shared__ TB smemB[cosize_v]; - Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K,PIPE) + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K,PIPE) // // Partition the copying of A and B tiles across the threads @@ -143,26 +154,35 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, // ThrMMA thr_mma = mma.get_slice(threadIdx.x); - Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) // Allocate registers for pipelining - Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0)); // (MMA,MMA_M,MMA_K) - 
Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0)); // (MMA,MMA_N,MMA_K) + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) // Allocate the accumulators -- same size as the projected data Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) - CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K) - CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K) CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N) - CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M - CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N - CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K + CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCrA))); // MMA_M + CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCrB))); // MMA_N // Clear the accumulators clear(tCrC); + // + // Copy Atom retiling + // + + TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma); + ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(threadIdx.x); + Tensor tXsA = s2r_thr_copy_a.partition_S(sA); // (CPY,MMA_M,MMA_K,PIPE) + Tensor tXrA = s2r_thr_copy_a.retile_D(tCrA); // (CPY,MMA_M,MMA_K) + + TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma); + ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(threadIdx.x); + Tensor tXsB = s2r_thr_copy_b.partition_S(sB); // (CPY,MMA_N,MMA_K,PIPE) + Tensor tXrB = s2r_thr_copy_b.retile_D(tCrB); // (CPY,MMA_N,MMA_K) + #if 0 if(thread0()) { print(" mA : "); print( mA); print("\n"); @@ -187,12 +207,15 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, if(thread0()) { print(" mC : "); print( mC); print("\n"); print(" gC : "); print( gC); print("\n"); - print("tCsA : "); print(tCsA); print("\n"); - print("tCsB : "); print(tCsB); print("\n"); print("tCgC : "); print(tCgC); print("\n"); print("tCrA : "); print(tCrA); print("\n"); print("tCrB : "); print(tCrB); print("\n"); print("tCrC : "); print(tCrC); print("\n"); + + print("tXsA : "); print(tXsA); print("\n"); + print("tXrA : "); print(tXrA); print("\n"); + print("tXsB : "); print(tXsB); print("\n"); + print("tXrB : "); print(tXrB); print("\n"); } #endif @@ -204,8 +227,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, int smem_pipe_write = K_PIPE_MAX-1; // Pipe slice - Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); - Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + Tensor tXsA_p = tXsA(_,_,_,smem_pipe_read); + Tensor tXsB_p = tXsB(_,_,_,smem_pipe_read); // Size of the register pipeline auto K_BLOCK_MAX = size<2>(tCrA); @@ -217,8 +240,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, __syncthreads(); // Prefetch the first rmem from the first k-tile - copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{})); - copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{})); + copy(s2r_atom_a, tXsA_p(_,_,Int<0>{}), tXrA(_,_,Int<0>{})); + copy(s2r_atom_b, tXsB_p(_,_,Int<0>{}), tXrB(_,_,Int<0>{})); } // @@ -243,8 +266,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, if (k_block == K_BLOCK_MAX - 1) { // Slice the smem_pipe_read smem - tCsA_p = tCsA(_,_,_,smem_pipe_read); - tCsB_p = tCsB(_,_,_,smem_pipe_read); + tXsA_p = tXsA(_,_,_,smem_pipe_read); + tXsB_p = tXsB(_,_,_,smem_pipe_read); // Commit the smem for smem_pipe_read cp_async_wait(); @@ -253,8 +276,8 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, // Load A, B shmem->regs for k_block+1 auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static - 
copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next)); - copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next)); + copy(s2r_atom_a, tXsA_p(_,_,k_block_next), tXrA(_,_,k_block_next)); + copy(s2r_atom_b, tXsB_p(_,_,k_block_next), tXrB(_,_,k_block_next)); // Copy gmem to smem before computing gemm on each k-pipe if (k_block == 0) { @@ -268,8 +291,7 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, // Advance the smem pipe smem_pipe_write = smem_pipe_read; - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read; + smem_pipe_read = (smem_pipe_read == K_PIPE_MAX-1) ? 0 : smem_pipe_read+1; } // Thread-level register gemm for k_block gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); @@ -286,6 +308,126 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, axpby(alpha, tCrC, beta, tCgC); } +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + cute::half_t const* A, int ldA, + cute::half_t const* B, int ldB, + Beta beta, + cute::half_t * C, int ldC, + cudaStream_t stream = 0) +{ + assert(false && "Not implemented"); +} + +// Setup params for a TN HGEMM +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + cute::half_t const* A, int ldA, + cute::half_t const* B, int ldB, + Beta beta, + cute::half_t * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline + + // Define the smem layouts (static) + // Swizzles for LDSM and 128b k-major loads + auto swizzle_atom = composition(Swizzle<3,3,3>{}, + Layout>, + Stride<_8,Stride<_1,_64>>>{}); + + auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP)); + auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP)); + auto sC = make_layout(make_shape(bM, bN)); + + // Define the thread layouts (static) + + TiledCopy copyA = make_tiled_copy(Copy_Atom, cute::half_t>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 16x8 k-major + Layout>{}); // Val layout 1x8 k-major + TiledCopy copyB = make_tiled_copy(Copy_Atom, cute::half_t>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 16x8 k-major + Layout>{}); // Val layout 1x8 n-major + + TiledMMA mmaC = make_tiled_mma(SM80_16x8x8_F16F16F16F16_TN{}, + Layout>{}, // 2x2x1 MMA Atoms + Tile<_32,_32,_16>{}); // 32x32x16 Tiled MMA for LDSM + + //Copy_Atom s2r_atom_A; + //Copy_Atom, half_t> s2r_atom_A; + //Copy_Atom s2r_atom_A; + //Copy_Atom s2r_atom_A; + Copy_Atom s2r_atom_A; + + //Copy_Atom s2r_atom_B; + //Copy_Atom, half_t> s2r_atom_B; + //Copy_Atom s2r_atom_B; + //Copy_Atom s2r_atom_B; + Copy_Atom s2r_atom_B; + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + int smem_size = int(sizeof(SharedStorage)); + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + + auto kernel_fptr = gemm_device< + decltype(prob_shape), decltype(cta_tiler), + cute::half_t, decltype(dA), decltype(sA), decltype(copyA), decltype(s2r_atom_A), + cute::half_t, decltype(dB), decltype(sB), decltype(copyB), 
decltype(s2r_atom_B), + cute::half_t, decltype(dC), decltype(sC), decltype(mmaC), + decltype(alpha), decltype(beta)>; + + // Set L1 to be SMEM only + cudaFuncSetAttribute( + kernel_fptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + cudaFuncSetAttribute( + kernel_fptr, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + kernel_fptr<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, s2r_atom_A, + B, dB, sB, copyB, s2r_atom_B, + C, dC, sC, mmaC, + alpha, beta); +} + // Setup params for a NT GEMM template @@ -347,13 +489,14 @@ gemm_nt(int m, int n, int k, print_latex(mmaC); #endif + int smem_size = int(sizeof(SharedStorage)); dim3 dimBlock(size(mmaC)); dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN))); - gemm_device<<>> + gemm_device<<>> (prob_shape, cta_tiler, - A, dA, sA, copyA, - B, dB, sB, copyB, + A, dA, sA, copyA, AutoVectorizingCopy{}, + B, dB, sB, copyB, AutoVectorizingCopy{}, C, dC, sC, mmaC, alpha, beta); } @@ -423,13 +566,14 @@ gemm_tn(int m, int n, int k, print_latex(mmaC); #endif + int smem_size = int(sizeof(SharedStorage)); dim3 dimBlock(size(mmaC)); dim3 dimGrid(size(ceil_div(M, bM)), size(ceil_div(N, bN))); - gemm_device<<>> + gemm_device<<>> (prob_shape, cta_tiler, - A, dA, sA, copyA, - B, dB, sB, copyB, + A, dA, sA, copyA, AutoVectorizingCopy{}, + B, dB, sB, copyB, AutoVectorizingCopy{}, C, dC, sC, mmaC, alpha, beta); } @@ -470,6 +614,11 @@ int main(int argc, char** argv) return 0; } + std::cout << "Using device 0: " << props.name + << " (SM" << props.major * 10 + props.minor + << ", " << props.multiProcessorCount + << ")" << std::endl; + int m = 5120; if (argc >= 2) sscanf(argv[1], "%d", &m); @@ -490,13 +639,13 @@ int main(int argc, char** argv) if (argc >= 6) sscanf(argv[5], "%c", &transB); - using TA = float; - using TB = float; - using TC = float; - using TI = float; + using TA = cute::half_t; + using TB = cute::half_t; + using TC = cute::half_t; + using TI = cute::half_t; - TI alpha = 1.0; - TI beta = 0.0; + TI alpha = static_cast(1.0f); + TI beta = static_cast(0.0f); std::cout << "M = " << m << std::endl; std::cout << "N = " << n << std::endl; diff --git a/examples/python/README.md b/examples/python/README.md index 590f2e24..0e69a409 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -20,3 +20,35 @@ * [04_epilogue_visitor](/examples/python/04_epilogue_visitor.ipynb) Shows how to fuse elementwise activation functions to GEMMs via the Python Epilogue Visitor interface + +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index b97fc4c8..91589538 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -32,6 +32,11 @@ #include // CUTLASS_ARCH_MMA_SMxx_ENABLED +// MMA SM90A +#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + // TMA instructions #if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) # define CUTE_ARCH_TMA_SM90_ENABLED @@ -48,41 +53,59 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// - -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) # define CUTE_ARCH_TMA_SM90_ENABLED # define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED # define CUTE_ARCH_STSM_SM90_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED # define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED # define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED # define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED # define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) # define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TCGEN05_TMEM_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TMA_SM100_ENABLED #endif // {add, mul, fma}.f32x2 PTX -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)) #define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif +#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) +# define CUTE_ARCH_MMA_SM120_ENABLED +# define CUTE_ARCH_TMA_SM120_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) +# define CUTE_ARCH_F8F6F4_MMA_ENABLED +# define 
CUTE_ARCH_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED +# endif +#endif + diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index f5f50647..095cde5b 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -208,7 +208,7 @@ to_CUtensorMapDataType() { if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else @@ -221,18 +221,18 @@ to_CUtensorMapDataType() { if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else - if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else #endif - + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } } @@ -258,7 +258,6 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B; #endif - } } } diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index a4bc3794..ec156449 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -56,6 +56,15 @@ struct SM90_TMA_LOAD_1D uint32_t smem_int_mbar = 
cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else asm volatile ( "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3}], [%2], %4;" @@ -63,6 +72,7 @@ struct SM90_TMA_LOAD_1D : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "l"(cache_hint) : "memory"); +#endif #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); #endif @@ -102,6 +112,15 @@ struct SM90_TMA_LOAD_2D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else asm volatile ( "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4}], [%2], %5;" @@ -109,6 +128,7 @@ struct SM90_TMA_LOAD_2D : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "l"(cache_hint) : "memory"); +#endif #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); #endif @@ -148,6 +168,15 @@ struct SM90_TMA_LOAD_3D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else asm volatile ( "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5}], [%2], %6;" @@ -155,6 +184,7 @@ struct SM90_TMA_LOAD_3D : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) : "memory"); +#endif #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); #endif @@ -194,6 +224,15 @@ struct SM90_TMA_LOAD_4D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else asm volatile ( "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " 
[%0], [%1, {%3, %4, %5, %6}], [%2], %7;" @@ -201,6 +240,7 @@ struct SM90_TMA_LOAD_4D : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) : "memory"); +#endif #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); #endif @@ -240,6 +280,15 @@ struct SM90_TMA_LOAD_5D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else asm volatile ( "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" @@ -247,6 +296,7 @@ struct SM90_TMA_LOAD_5D : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) : "memory"); +#endif #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); #endif @@ -581,6 +631,9 @@ struct SM90_TMA_LOAD_MULTICAST_1D int32_t const& crd0) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); @@ -607,6 +660,9 @@ struct SM90_TMA_LOAD_MULTICAST_2D int32_t const& crd0, int32_t const& crd1) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); @@ -633,6 +689,9 @@ struct SM90_TMA_LOAD_MULTICAST_3D int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); @@ -659,6 +718,9 @@ struct SM90_TMA_LOAD_MULTICAST_4D int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); @@ -685,6 +747,9 @@ struct SM90_TMA_LOAD_MULTICAST_5D int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t 
gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); @@ -757,6 +822,9 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D uint16_t const& offset_w) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); @@ -786,6 +854,9 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D uint16_t const& offset_w, uint16_t const& offset_h) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); @@ -815,6 +886,9 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); diff --git a/include/cute/arch/mma_sm100_desc.hpp b/include/cute/arch/mma_sm100_desc.hpp index 3d748061..f15108a4 100644 --- a/include/cute/arch/mma_sm100_desc.hpp +++ b/include/cute/arch/mma_sm100_desc.hpp @@ -552,7 +552,8 @@ make_runtime_instr_desc(UMMA::InstrDescriptor desc_i, uint16_t sparse_id2 = 0u, template + bool is_sparse = false + > CUTE_HOST_DEVICE constexpr UMMA::InstrDescriptorBlockScaled make_instr_desc_block_scaled() diff --git a/include/cute/arch/mma_sm100_umma.hpp b/include/cute/arch/mma_sm100_umma.hpp index 1f74223b..4b6d7f86 100644 --- a/include/cute/arch/mma_sm100_umma.hpp +++ b/include/cute/arch/mma_sm100_umma.hpp @@ -28,9 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ -// - -// #pragma once @@ -303,6 +300,92 @@ struct SM100_MMA_F16BF16_TS_SCALED } }; +template +struct SM100_MMA_TF32_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_TF32_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::tf32 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_SS_SPARSE without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::f16 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS_SPARSE without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + template @@ -551,6 +634,88 @@ struct SM100_MMA_F16BF16_2x1SM_TS_SCALED } }; +template +struct SM100_MMA_TF32_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_SS_SPARSE N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, 
+ uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::tf32 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SPARSE N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::f16 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + template @@ -632,6 +797,47 @@ struct SM100_MMA_S8_TS } }; +template +struct SM100_MMA_S8_SS_SPARSE +{ + static_assert(is_same_v, "SM100_MMA_S8_SS_SPARSE result type can only be int32_t."); + static_assert(M == 64 || M == 128, "SM100_MMA_S8_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_SS_SPARSE N-mode size should be 8 or a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::i8 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS_SPARSE without 
CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + template @@ -714,10 +920,49 @@ struct SM100_MMA_S8_2x1SM_TS } }; +template +struct SM100_MMA_S8_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::i8 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + struct SM100_MMA_F8F6F4_SS { - - using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; @@ -876,6 +1121,91 @@ struct SM100_MMA_F8F6F4_2x1SM_TS } }; +template +struct SM100_MMA_F8F6F4_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F8F6F4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; // %5, %6, %7, %8 + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::f8f6f4 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_SS_SPARSE +{ + static_assert(M == 128, "SM100_MMA_MXF8F6F4_SS_SPARSE M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + struct SM100_MMA_F8F6F4_2x1SM_SS { using DRegisters = void; @@ -910,6 +1240,47 @@ struct SM100_MMA_F8F6F4_2x1SM_SS } }; +template +struct SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE +{ + static_assert(M == 256, "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE M-mode size should be 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + template @@ -950,6 +1321,46 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS } }; +template +struct SM100_MMA_F8F6F4_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::f8f6f4 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM100_MMA_F8F6F4_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; template +struct SM100_MMA_MXF4NVF4_SS_SPARSE +{ + static_assert(M == 128, "SM100_MMA_MXF4NVF4_SS_SPARSE M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4NVF4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_SS_SPARSE (VS = 32) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + + if constexpr (VS == 64) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_SS_SPARSE (VS = 64) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; template +struct SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE +{ + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 32) || (VS == 64), "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE Vector size can only be 32 or 64."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE (VS = 32) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + + if constexpr (VS == 64) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) 
+    if (cute::elect_one_sync()) {
+      asm volatile(
+        "{\n\t"
+        ".reg .pred p;\n\t"
+        "setp.ne.b32 p, %4, 0;\n\t"
+        "tcgen05.mma.sp.cta_group::2.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t"
+        "}\n"
+        :
+        : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
+          "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e));
+    }
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE (VS = 64) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED");
+#endif
+    }
+  }
+};
} // end namespace cute
diff --git a/include/cute/arch/mma_sm120.hpp b/include/cute/arch/mma_sm120.hpp
new file mode 100644
index 00000000..84c09b8b
--- /dev/null
+++ b/include/cute/arch/mma_sm120.hpp
@@ -0,0 +1,3254 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+
+#pragma once
+
+#include
+#include
+#include // cute::float_e4m3_t, etc
+#include
+
+namespace cute {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct SM120_16x8x32_TN
+{
+  static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x32_TN for given data types.");
+};
+
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x32 TN E2M1 x E2M1
+template <>
+struct SM120_16x8x32_TN
+{
+  using DRegisters = float[4];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[2];
+  using CRegisters = float[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(float         & d0, float         & d1, float         & d2, float         & d3,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1,
+      float const   & c0, float const   & c1, float const   & c2, float const   & c3)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m1.f32 "
+      "{%0,  %1,  %2,  %3},"
+      "{%4,  %5,  %6,  %7},"
+      "{%8,  %9},"
+      "{%10, %11, %12, %13};\n"
+      : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
+      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
+         "r"(b0),  "r"(b1),
+         "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x32 TN E2M1 x E3M2
+template <>
+struct SM120_16x8x32_TN
+{
+  using DRegisters = float[4];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[2];
+  using CRegisters = float[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(float         & d0, float         & d1, float         & d2, float         & d3,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1,
+      float const   & c0, float const   & c1, float const   & c2, float const   & c3)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e3m2.f32 "
+      "{%0,  %1,  %2,  %3},"
+      "{%4,  %5,  %6,  %7},"
+      "{%8,  %9},"
+      "{%10, %11, %12, %13};\n"
+      : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
+      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
+         "r"(b0),  "r"(b1),
+         "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x32 TN E2M1 x E2M3
+template <>
+struct SM120_16x8x32_TN
+{
+  using DRegisters = float[4];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[2];
+  using CRegisters = float[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(float         & d0, float         & d1, float         & d2, float         & d3,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1,
+      float const   & c0, float const   & c1, float const   & c2, float const   & c3)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m3.f32 "
+      "{%0,  %1,  %2,  %3},"
+      "{%4,  %5,  %6,  %7},"
+      "{%8,  %9},"
+      "{%10, %11, %12, %13};\n"
+      : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
+      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
+         "r"(b0),  "r"(b1),
+
"f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MMA 16x8x32 TN E3M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E3M2 +template <> 
+struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, 
uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E2M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, 
%7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E4M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); 
+#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters 
= float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E5M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if 
defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.F16 16x8x32 TN E2M1 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, 
uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E4M3 +template <> 
+struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : 
"=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.F16 16x8x32 TN E4M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, 
uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + 
} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) 
+ asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM120::BLOCKSCALED { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x32_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x32_TN_VS for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x64_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x64_TN_VS for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & 
d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)), "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float 
const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + 
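+    // tidA/bidA (and tidB/bidB below) are the thread-id / byte-id selectors that the
+    // block_scale PTX form takes alongside each scale factor. They stay 0 here, so each
+    // thread supplies its scale factor in byte 0 of the 32-bit "r" operand formed from
+    // the uint8_t sfa0/sfb0 arguments.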
static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + 
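+      // Operand bundles, in order: {d0..d3} outputs, {a0..a3}, {b0, b1}, {c0..c3} accumulators,
+      // then {sfa0}, {byte-id-a, thread-id-a}, {sfb0}, {byte-id-b, thread-id-b}.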
"mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, 
%16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , 
"h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); 
+#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E2M1 with SF UE8M0 +template +struct 
SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using 
SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t 
const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + 
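+    // The mxf8f6f4 atoms are emitted only when CUTE_ARCH_MXF8F6F4_MMA_ENABLED is defined;
+    // otherwise the #else branch below falls back to CUTE_INVALID_CONTROL_PATH so the header
+    // still compiles for other targets. These atoms use the scale_vec::1X form, which is why
+    // the assert below pins the scale-factor vector size VS to 32; the mxf4nvf4 16x8x64 atoms
+    // further down relax this to VS == 16 or VS == 32.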
static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector 
size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x64 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + static constexpr int SFBits = (VS == 32) ? 
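+                                          // VS == 32: two scale-factor bytes (16 bits) per thread;
+                                          // VS == 16: four scale-factor bytes (32 bits).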
16 : 32; + using RegTypeSF = uint_bit_t; + + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + RegTypeSF const& sfa0, + RegTypeSF const& sfb0) + { + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + CUTE_STATIC_ASSERT(VS == 16 || VS == 32, "Scaling factor vector size has to be 16 or 32 for MXF4NVF4 MMA."); + +#if defined(CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.SF 16x8x64 TN E2M1 x E2M1 with SF E4M3 +template +struct SM120_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + static constexpr int SFBits = (VS == 32) ? 
16 : 32; + using RegTypeSF = uint_bit_t; + + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + RegTypeSF const& sfa0, + RegTypeSF const& sfb0) + { + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 16 || VS == 32, "Scaling factor vector size has to be 16 or 32 for MXF4NVF4 MMA."); +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED"); +#endif + + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM120::BLOCKSCALED + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class ElementB, + class ElementC +> +CUTE_HOST_DEVICE constexpr +auto +rr_op_selector_sm120() +{ + return SM120_16x8x32_TN{}; +} + +template < + class ElementA, + class ElementB, + class ElementC, + class ElementSF, + int SFVecSize, + bool UseF8F6F4 +> +CUTE_HOST_DEVICE constexpr +auto +rr_blockscaled_op_selector_sm120() +{ + if constexpr (UseF8F6F4) { + return SM120::BLOCKSCALED::SM120_16x8x32_TN_VS{}; + } + else{ + return SM120::BLOCKSCALED::SM120_16x8x64_TN_VS{}; + } +} + +} // namespace cute diff --git a/include/cute/arch/mma_sm120_sparse.hpp b/include/cute/arch/mma_sm120_sparse.hpp new file mode 100644 index 00000000..a6950032 --- /dev/null +++ b/include/cute/arch/mma_sm120_sparse.hpp @@ -0,0 +1,3444 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::float_e4m3_t, etc +#include +#include + +namespace cute { + +namespace SM120::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_SPARSE_16x8x64_TN +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_SPARSE_16x8x64_TN for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, 
%14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float 
const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using 
ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : 
"=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float 
const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + 
CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : 
"=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + 
"mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm 
volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + 
asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if 
defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { 
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) 
+ { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& 
e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t 
const& e)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e2m3.f16 "
+      "{%0, %1},"
+      "{%2, %3, %4, %5},"
+      "{%6, %7, %8, %9},"
+      "{%10, %11},"
+      "%12, 0x0;\n"
+      : "=r"(d0), "=r"(d1)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "r"(c0), "r"(c1),
+        "r"(e));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E4M3
+template <>
+struct SM120_SPARSE_16x8x64_TN
+{
+  using DRegisters = uint32_t[2];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[4];
+  using CRegisters = uint32_t[2];
+  using ERegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t & d0, uint32_t & d1,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3,
+      uint32_t const& c0, uint32_t const& c1,
+      uint32_t const& e)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e4m3.f16 "
+      "{%0, %1},"
+      "{%2, %3, %4, %5},"
+      "{%6, %7, %8, %9},"
+      "{%10, %11},"
+      "%12, 0x0;\n"
+      : "=r"(d0), "=r"(d1)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "r"(c0), "r"(c1),
+        "r"(e));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E5M2
+template <>
+struct SM120_SPARSE_16x8x64_TN
+{
+  using DRegisters = uint32_t[2];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[4];
+  using CRegisters = uint32_t[2];
+  using ERegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t & d0, uint32_t & d1,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3,
+      uint32_t const& c0, uint32_t const& c1,
+      uint32_t const& e)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+      "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e5m2.f16 "
+      "{%0, %1},"
+      "{%2, %3, %4, %5},"
+      "{%6, %7, %8, %9},"
+      "{%10, %11},"
+      "%12, 0x0;\n"
+      : "=r"(d0), "=r"(d1)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "r"(c0), "r"(c1),
+        "r"(e));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+} // end namespace SM120::SPARSE
+
+namespace SM120::BLOCKSCALED::SPARSE {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct SM120_SPARSE_16x8x64_TN_VS
+{
+  static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_SPARSE_16x8x64_TN_VS for given data types.");
+};
+
+template
+struct SM120_SPARSE_16x8x128_TN_VS
+{
+  static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_SPARSE_16x8x128_TN_VS for given data types.");
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M1, SF UE8M0
+template
+struct SM120_SPARSE_16x8x64_TN_VS
+{
+  using DRegisters = float[4];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[4];
+  using CRegisters = float[4];
+  using ERegisters = uint32_t[1];
+  using SFARegisters = uint8_t[1];
+  using SFBRegisters = uint8_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(float & d0, float & d1, float & d2, float & d3,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3,
+      float const& c0, float const& c1, float const& c2, float const& c3,
+      uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb)
+  {
+#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED)
+    static constexpr uint16_t tidA = 0;
+    static constexpr uint16_t bidA = 0;
+    static constexpr uint16_t tidB = 0;
+    static constexpr uint16_t bidB = 0;
+
+    CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA.");
+
+    asm volatile(
+      "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
+      "{%0, %1, %2, %3},"
+      "{%4, %5, %6, %7},"
+      "{%8, %9, %10, %11},"
+      "{%12, %13, %14, %15},"
+      "{%16}, 0x0,"
+      "{%17},"
+      "{%18, %19},"
+      "{%20},"
+      "{%21, %22};\n"
+      : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "f"(c0), "f"(c1), "f"(c2), "f"(c3),
+        "r"(e),
+        "r"((uint32_t)sfa), "h"(bidA), "h"(tidA),
+        "r"((uint32_t)sfb), "h"(bidB), "h"(tidB));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E3M2, SF UE8M0
+template
+struct SM120_SPARSE_16x8x64_TN_VS
+{
+  using DRegisters = float[4];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[4];
+  using CRegisters = float[4];
+  using ERegisters = uint32_t[1];
+  using SFARegisters = uint8_t[1];
+  using SFBRegisters = uint8_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(float & d0, float & d1, float & d2, float & d3,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3,
+      float const& c0, float const& c1, float const& c2, float const& c3,
+      uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb)
+  {
+#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED)
+    static constexpr uint16_t tidA = 0;
+    static constexpr uint16_t bidA = 0;
+    static constexpr uint16_t tidB = 0;
+    static constexpr uint16_t bidB = 0;
+
+    CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA.");
+
+    asm volatile(
+      "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e3m2.f32.ue8m0 "
+      "{%0, %1, %2, %3},"
+      "{%4, %5, %6, %7},"
+      "{%8, %9, %10, %11},"
+      "{%12, %13, %14, %15},"
+      "{%16}, 0x0,"
+      "{%17},"
+      "{%18, %19},"
+      "{%20},"
+      "{%21, %22};\n"
+      : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "f"(c0), "f"(c1), "f"(c2), "f"(c3),
+        "r"(e),
+        "r"((uint32_t)sfa), "h"(bidA), "h"(tidA),
+        "r"((uint32_t)sfb), "h"(bidB), "h"(tidB));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without
CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + 
"r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : 
"=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e2m3.f32.ue8m0 " + "{%0, 
%1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 
MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + 
static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t 
const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, 
uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static 
void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + 
using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x 
E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + 
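For reference: each `SM120_SPARSE_16x8x64_TN_VS` specialization in this hunk wraps the same `mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64` instruction and differs only in the A/B element encodings named in the PTX string. A minimal host-side sketch of the math one such m16n8k64 tile performs is shown below; it assumes the bias-127, power-of-two reading of a UE8M0 scale factor, takes A and B already decoded to `float`, and uses hypothetical `A_vals`/`A_idx` buffers to stand in for the 2:4-compressed A operand and its metadata (none of these names come from the diff).

```cpp
// Illustrative reference only: one block-scaled, 2:4-sparse m16n8k64 tile,
// D = (2^(sfa-127) * A) * (2^(sfb-127) * B) + C, with A/B already decoded
// from their e2m1/e3m2/e2m3/e4m3/e5m2 encodings to float.
#include <array>
#include <cmath>
#include <cstdint>

constexpr int M = 16, N = 8, K = 64;

// UE8M0 scale factor: an unsigned 8-bit exponent; assumed value = 2^(sf - 127).
inline float decode_ue8m0(uint8_t sf) { return std::ldexp(1.0f, int(sf) - 127); }

// A is 2:4 structured-sparse: for every 4 logical K positions only 2 values
// are stored; A_idx records the logical K slot of each stored value.
void blockscaled_sparse_mma_ref(
    std::array<float, M * N>&            D,       // row-major accumulators
    std::array<float, M * K / 2> const&  A_vals,  // kept (nonzero) A values, row-major
    std::array<uint8_t, M * K / 2> const& A_idx,  // logical K index of each kept value
    std::array<float, N * K> const&      B,       // column-major (TN layout), dense
    std::array<float, M * N> const&      C,
    uint8_t sfa, uint8_t sfb)                     // one scale factor per K = 64 block
{
  float const scale = decode_ue8m0(sfa) * decode_ue8m0(sfb);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      // Only the kept 2-of-4 A values contribute; the metadata selects their K slot.
      for (int i = 0; i < K / 2; ++i) {
        int k = A_idx[m * (K / 2) + i];
        acc += A_vals[m * (K / 2) + i] * B[n * K + k];
      }
      D[m * N + n] = scale * acc + C[m * N + n];
    }
  }
}
```

The hardware path additionally packs A/B into 32-bit registers, carries the sparsity metadata in `e`, and passes the `bidA`/`tidA` and `bidB`/`tidB` selector operands seen in the asm; the sketch ignores that packing and simply applies one `sfa`/`sfb` pair per K=64 block, which is how these atoms treat `scale_vec::1X` with `VS == 64`.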
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), 
"h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x128 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_SPARSE_16x8x128_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + static constexpr int SFBits = (VS == 64) ? 16 : 32; + using RegTypeSF = uint_bit_t; + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, RegTypeSF const& sfa, RegTypeSF const& sfb) + { + + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64 || VS == 32, "Scaling factor vector size has to be 64 or 32 for MXF4NVF4."); +#if defined(CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.scale_vec::2X.m16n8k128.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}, {%18, %19}," + "{%20}, {%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"(uint32_t(sfa)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::SPARSE::SM120_SPARSE_16x8x128_TN_VS without CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x128 TN E2M1 x E2M1 with SF E4M3 +template +struct SM120_SPARSE_16x8x128_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + static constexpr int SFBits = (VS == 64) ? 
16 : 32; + using RegTypeSF = uint_bit_t; + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, RegTypeSF const& sfa, RegTypeSF const& sfb) + { +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for NVF4 with e2m1 and scale factor e4m3."); + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.scale_vec::4X.m16n8k128.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}, {%18, %19}," + "{%20}, {%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"(uint32_t(sfa)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::SPARSE::SM120_SPARSE_16x8x128_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace SM120::BLOCKSCALED::SPARSE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class ElementB, + class ElementC +> +CUTE_HOST_DEVICE constexpr +auto +rr_sparse_op_selector_sm120() +{ + // Get MMA SPARSE OP + return SM120::SPARSE::SM120_SPARSE_16x8x64_TN{}; +} + +template < + class ElementA, + class ElementB, + class ElementC, + class ElementSF, + int SFVecSize, + bool UseF8F6F4 +> +CUTE_HOST_DEVICE constexpr +auto +rr_blockscaled_sparse_op_selector_sm120() +{ + if constexpr (UseF8F6F4) { + return SM120::BLOCKSCALED::SPARSE::SM120_SPARSE_16x8x64_TN_VS{}; + } + else { + return SM120::BLOCKSCALED::SPARSE::SM120_SPARSE_16x8x128_TN_VS{}; + } +} + +} // namespace cute diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp index e5eff988..5f65746b 100644 --- a/include/cute/arch/mma_sm90_desc.hpp +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -31,15 +31,10 @@ #pragma once -#include +#include #include -// Config -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -# define CUTE_ARCH_MMA_SM90A_ENABLED -#endif - //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cute { diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index e342dbb2..beefa63f 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -626,7 +626,7 @@ make_tma_atom_im2col(CopyOp, auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); // Split according to the portion each multicast CTA will be responsible for - auto tma_layout_vt = 
logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), num_multicast)); + auto tma_layout_vt = logical_divide(tma_layout_trunc, safe_div(size(tma_layout_trunc), num_multicast)); #if 0 print("glayout_basis : "); print(glayout_basis); print("\n"); @@ -748,7 +748,7 @@ make_tma_copy_im2col(CopyOp const& copy_op, // Scale that up to cover all of the smem_coords auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); // CTA T -> smem idx - auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); + auto layout_t = make_layout(cosize(cta_t_map), safe_div(num_elems_per_tma, cosize(cta_t_map))); // CTA TID -> smem coord auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); // Combine with the T mapping diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 9c30ca53..ad668cee 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -1165,7 +1165,7 @@ make_tma_copy_tiled(CopyOp const& copy_op, // Scale that up to cover all of the smem_coords auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); // CTA T -> smem idx - auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); + auto layout_t = make_layout(cosize(cta_t_map), safe_div(num_elems_per_tma, cosize(cta_t_map))); // CTA TID -> smem coord auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); // Combine with the T mapping @@ -1400,16 +1400,19 @@ tma_partition(Copy_Atom const& copy_atom, } // TMA Multicast Masks Calculation -template +template CUTE_HOST_DEVICE constexpr uint16_t create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, CtaCoord const& cta_coord_vmnk) { - auto cta_coord_slicer = replace(cta_coord_vmnk, _); - auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_slicer, cta_layout_vmnk); + auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_vmnk, cta_layout_vmnk); uint16_t mcast_mask = 0; + if constexpr (rank_v == 0) { + // Trivial case with no additional ctas + mcast_mask = uint16_t(1); + } else if constexpr (rank_v == 1 and depth_v <= 1 and not is_static::value) { // Get the instruction code -- optimized for dynamic flat-rank-1 cta_layout @@ -1432,6 +1435,16 @@ create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, return mcast_mask; } +// Projections multicast mask +template +CUTE_HOST_DEVICE constexpr +uint16_t +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + return create_tma_multicast_mask(cta_layout_vmnk, replace(cta_coord_vmnk, _)); +} + //////////////////////////////////// // Make TMA copy A/B/C /////////////////////////////////// diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index fe2f3e0a..21b8d5d2 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -154,9 +154,10 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeA is a view (of the current tensor), forward the whole static_assert(is_same::value_type>::value - || (sizeof_bits_v::value_type> == 8 && (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) + || (sizeof_bits_v::value_type> == 4 && + (sizeof_bits_v == 4 || sizeof_bits_v == 3 || sizeof_bits_v == 2)) , "Expecting ValTypeA type"); return make_tensor(static_cast(atensor)); } else { @@ -1117,4 +1118,7 @@ print_svg(TiledMMA const &mma) { #include #include #include 
+#include +#include + //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index ff7d5c55..f336eff2 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -29,8 +29,6 @@ * **************************************************************************************************/ - - #pragma once #include @@ -73,6 +71,41 @@ using Layout_MN_SW128_32B_Atom_Bits = ComposedLayout, smem_ptr_fl template using Layout_MN_SW128_32B_Atom = decltype(upcast::value>(Layout_MN_SW128_32B_Atom_Bits{})); +////////////////////////////////////////////////// +// Common layouts for Sparse UMMA Shared Memory // +////////////////////////////////////////////////// + +using cute::GMMA::Layout_MN_INTER_SpAtom; +using cute::GMMA::Layout_MN_SW32_SpAtom; +using cute::GMMA::Layout_MN_SW64_SpAtom; +using cute::GMMA::Layout_MN_SW128_SpAtom; +using cute::GMMA::Layout_K_INTER_SpAtom; +using cute::GMMA::Layout_K_SW32_SpAtom; +using cute::GMMA::Layout_K_SW64_SpAtom; +using cute::GMMA::Layout_K_SW128_SpAtom; + +template +using Layout_MN_SW128_32B_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW128_32B_Atom{}.layout_b()))>; + +// With UMMA::Major param +template +using Layout_INTER_SpAtom = typename conditional, + Layout_K_INTER_SpAtom>::type; +template +using Layout_SW32_SpAtom = typename conditional, + Layout_K_SW32_SpAtom>::type; +template +using Layout_SW64_SpAtom = typename conditional, + Layout_K_SW64_SpAtom>::type; +template +using Layout_SW128_SpAtom = typename conditional, + Layout_K_SW128_SpAtom>::type; + // Tile a MN-logical layout atom to an MMA Tile Shape ((MMA_M,MMA_N),M_MMAs,N_MMAs,...) template CUTE_HOST_DEVICE constexpr @@ -212,16 +245,11 @@ make_umma_desc(Tensor const& tensor) constexpr int SwizzleAtomKSize = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 4 : 8; - // Construct the canonical UMMA T Layout with shape - // ((SwizzleAtomMNSize,n),(SwizzleAtomKSize,2)) - Layout canonical_layout = - logical_divide(layout(u128_tensor), - make_tile(Layout, _1>{}, - Layout, _1>{})); + // Construct the canonical UMMA T Layout with shape ((SwizzleAtomMNSize,n),(SwizzleAtomKSize,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile>,Layout>>{}); - // Check ranks of canonical - CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_MN Layout: No flat offset mode"); - CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_MN Layout: No flat offset mode"); + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical UMMA_MN Layout: Expected profile failure."); // Check canonical mode strides constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE ? 
stride<0,0>(canonical_layout) : 1; @@ -253,11 +281,10 @@ make_umma_desc(Tensor const& tensor) "Not a canonical UMMA_K Layout: Expected MN-size multiple of 8."); // Construct the canonical UMMA N Layout with shape ((8,n),(2,1)) - Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,Layout<_2,_1>>{}); - // Check ranks of canonical - CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_K Layout: No flat offset mode"); - CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical UMMA_K Layout: No flat offset mode"); + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical UMMA_K Layout: Expected profile failure."); // Check canonical mode strides constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); constexpr uint32_t expected_stride_00 = SwizzleAtomMNSize; @@ -1396,6 +1423,182 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 4); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_SS_SPARSE supports 32bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_TF32_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 2); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_SPARSE supports 16bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_F16BF16_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); // Move to a CoupledTensor rather than a .with()? 
+ return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 4); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_SS_SPARSE supports 32bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_TF32_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 2); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SPARSE supports 16bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_F16BF16_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + template @@ -1933,6 +2312,94 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_SS_SPARSE supports 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = 0; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_S8_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + template @@ -2062,6 +2529,94 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_SS_SPARSE supports 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_S8_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + template @@ -2171,7 +2726,7 @@ struct MMA_Traits,Int>>, Stride<_0,Stride< _1,Int>>>; using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS; // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] @@ -2222,6 +2777,107 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_SS_SPARSE supports types with leq 8bit types"); + + // Logical shape-K is always 512bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = 64; + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS_SPARSE; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF8F6F4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable MMA_traits with sp into set. 
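+  // Usage sketch (hypothetical tensor names, illustration only): given tmem-resident fragments
+  // tEtE (sparsity metadata), tSFA and tSFB (scale factors),
+  //   auto exec_traits = traits.with(UMMA::ScaleOut::One, tEtE, tSFA, tSFB);
+  // returns a copy of these traits with the metadata and scale-factor tmem addresses bound,
+  // ready to be consumed by mma_unpack.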
+ template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_F8F6F4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + template @@ -2427,6 +3171,92 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_F8F6F4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; template ,Int>>, Stride,Stride< _1,Int>>>; using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS 64 ? M/2 : M), (N == 192 ? 256 : N), a_major, b_major, + (M/2 > 64 ? M/2 : M), (round_up(N, 128)), a_major, b_major, a_neg, b_neg>; // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] @@ -2519,7 +3349,106 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE supports types with leq 8bit types"); + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 64; + constexpr static int SFVecSize = 64; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS_SPARSE 64 ? M/2 : M), (round_up(N, 128)), a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; template ,Int>>, Stride<_0,Stride< _1,Int>>>; using MMA_ScaleFactor = SM100_MMA_MXF4_SS; // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] @@ -2612,7 +3541,107 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4NVF4_SS_SPARSE supports 4bit types"); + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 128; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 64 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 32), + "2x mode (VectorSize=64) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4NVF4_SS_SPARSE; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF4NVF4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; template ,Int>>, Stride,Stride< _1,Int>>>; using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (N == 192 ? 256 : N), VS, a_major, b_major, + (M/2 > 64 ? 
M/2 : M), (round_up(N, 128)), VS, a_major, b_major, a_neg, b_neg>; // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] @@ -2703,5 +3732,110 @@ struct MMA_Traits +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 128; + constexpr static int SFVecSize = VS; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeA = UMMA::sparse_smem_desc; + // using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 64 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 32), + "2x mode (VectorSize=64) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4NVF4_SS_SPARSE 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? 
+ uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; } // end namespace cute diff --git a/include/cute/atom/mma_traits_sm120.hpp b/include/cute/atom/mma_traits_sm120.hpp new file mode 100644 index 00000000..e3399801 --- /dev/null +++ b/include/cute/atom/mma_traits_sm120.hpp @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +namespace SM120::BLOCKSCALED { + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A_zipped, + Tensor const& B_zipped, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using RegTypeSFA = typename remove_extent::type; + using RegTypeSFB = typename remove_extent::type; + + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + constexpr int RegNumSFA = extent::value; + constexpr int RegNumSFB = extent::value; + + auto [A, SFA] = unzip_tensor(A_zipped); + auto [B, SFB] = unzip_tensor(B_zipped); + + using Shape_MNK = typename MMA_Traits::Shape_MNK; + constexpr int SFVecSize = MMA_Traits::SFVecSize; + + // Assert logical size + CUTE_STATIC_ASSERT_V(size(SFA) == size<2>(Shape_MNK{})); + CUTE_STATIC_ASSERT_V(size(SFB) == size<2>(Shape_MNK{})); + + // Assert physical size + CUTE_STATIC_ASSERT(decltype(cosize(layout(SFA))){} == size<2>(Shape_MNK{}) / SFVecSize); + CUTE_STATIC_ASSERT(decltype(cosize(layout(SFB))){} == size<2>(Shape_MNK{}) / SFVecSize); + + Tensor rA = recast(A); + Tensor rB = recast(B); + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + Tensor rD = recast(D); + Tensor rC = recast(C); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + Tensor rSFA = recast(filter_zeros(SFA)); + Tensor rSFB = recast(filter_zeros(SFB)); + + CUTE_STATIC_ASSERT_V(size(rSFA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rSFB) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rSFA, make_int_sequence{}, + rSFB, make_int_sequence{}); +} +} // namespace SM120::BLOCKSCALED + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA F8F6F4 16x8x32 TN +template +struct MMA_Traits> + : MMA_Traits +{ + // The MMA accepts 8-bit inputs regardless of the types for A and B + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + + using ValTypeD = c_type; + using ValTypeC = c_type; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA MXF8F6F4 16x8x64 TN +template +struct MMA_Traits> +{ + // The MMA accepts 4-bit inputs regardless of the types for A and B + using ValTypeA = uint4_t; + using ValTypeB = uint4_t; + + using ValTypeD = c_type; + using ValTypeC = c_type; + + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + using Shape_MNK = Shape<_16,_8,_64>; + using ThrID = Layout<_32>; + + // (T32,V32) -> (M16,K64) + using ALayout = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_16,_8,_512>>>; + // (T32,V16) -> (M16,K64) + using BLayout 
= Layout,Shape <_8, _2>>, + Stride,Stride<_8,_256>>>; + // (T32,V64) -> (M16,K64) + using SFALayout = Layout,_64>, // Effectively 16 threads due to the 2:0 mode + Stride,_16>>; + // (T32,V64) -> (N8,K64) + using SFBLayout = Layout,_64>, // Effectively 8 threads due to the 4:0 mode + Stride, _8>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM80_16x8_Row; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA MXF8F6F4 16x8x32 TN +template +struct MMA_Traits> +{ + using UnderlyingTraits = MMA_Traits>; + + // The MMA accepts 8-bit inputs regardless of the types for A and B + using ValTypeA = typename UnderlyingTraits::ValTypeA; + using ValTypeB = typename UnderlyingTraits::ValTypeB; + + using ValTypeD = typename UnderlyingTraits::ValTypeD; + using ValTypeC = typename UnderlyingTraits::ValTypeC; + + using Shape_MNK = typename UnderlyingTraits::Shape_MNK; + using ThrID = typename UnderlyingTraits::ThrID; + + using ALayout = typename UnderlyingTraits::ALayout; + using BLayout = typename UnderlyingTraits::BLayout; + using CLayout = typename UnderlyingTraits::CLayout; + + // Scaling factor + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + // (T32,V32) -> (M16,K32) + using SFALayout = Layout,_32>, // Effectively 16 threads due to the 2:0 mode + Stride,_16>>; + // (T32,V32) -> (N8,K32) + using SFBLayout = Layout,_32>, // Effectively 8 threads due to the 4:0 mode + Stride, _8>>; +}; + +// Transform if needed +template +CUTLASS_DEVICE void +fp4_shift_A(MMA_Op const& op, Tensor&& tensor) { +} +template +CUTLASS_DEVICE void +fp4_shift_B(MMA_Op const& op, Tensor&& tensor) { +} + +// For SM120 MMA F8F6F4 input fp4, the operand A/B are load from ld.matrix. +// ld.matrix b4x16_p64 places FP4 data at the first four bits in each +// eight-bit container, whereas MMA F8F6F4 expects the four-bit data to be in +// the middle of the eight-bit container. Thus, e2m1 operands being fed +// to MMA F8F6F4 must be shifted left by two bits. +// 0b0000ABCD --> 0b00ABCD00 +// NOTE: Same transformation is NOT needed for FP6 and FP8. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120_16x8x32_TN const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_16x8x32_TN const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +namespace SM120::BLOCKSCALED { + +// Template function with scale factor needs to enmuerate types one by one, as template +// arguments contatins two variadic lists, which cannot be deduced in one shot. 
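As an aside on the deduction limitation noted in the comment above: two bare template parameter packs cannot both be deduced from one flat argument list, which is why these blockscaled overloads enumerate the scale-factor parameters explicitly. A minimal standalone sketch (plain C++17, hypothetical function names, unrelated to the CUTLASS ops):

```cpp
#include <tuple>

// Wrapping each pack in its own deducible context (here std::tuple) lets both
// packs be deduced from a single call.
template <class... As, class... Bs>
constexpr int paired(std::tuple<As...>, std::tuple<Bs...>) {
  return int(sizeof...(As)) * 10 + int(sizeof...(Bs));
}

// With two bare packs, the compiler cannot split one flat argument list between
// them: the leading pack has to be enumerated explicitly at the call site.
template <class... As, class... Bs>
constexpr int flat(As..., Bs...) {
  return int(sizeof...(As)) * 10 + int(sizeof...(Bs));
}

static_assert(paired(std::tuple<int, float>{}, std::tuple<char>{}) == 21);
static_assert(flat<int, float>(0, 0.f, 'a') == 21);  // As... spelled out, Bs... deduced as {char}
```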
+template +CUTLASS_DEVICE void +fp4_shift_A(SM120::BLOCKSCALED::SM120_16x8x32_TN_VS const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120::BLOCKSCALED::SM120_16x8x32_TN_VS const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm120_sparse.hpp b/include/cute/atom/mma_traits_sm120_sparse.hpp new file mode 100644 index 00000000..b62c576d --- /dev/null +++ b/include/cute/atom/mma_traits_sm120_sparse.hpp @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V4) -> (M16,N8) +using SM120_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +namespace SM120::BLOCKSCALED::SPARSE +{ + +// Unpack explode/mma call with sparse and block scalaring inputs. 
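The blockscaled sparse mma_unpack that follows funnels every register fragment (D, A, B, C, E, SFA, SFB) into a single variadic fma call via detail::explode and integer sequences. A simplified, self-contained sketch of that expansion idiom (plain C++, hypothetical names, not CuTe's actual detail::explode):

```cpp
#include <cstddef>
#include <cstdio>
#include <utility>

// Pretend "fma" intrinsic that takes every register as a separate scalar argument.
inline void fake_fma(float d0, float d1, float a0, float a1, float b0, float c0, float c1) {
  std::printf("%f %f %f %f %f %f %f\n", d0, d1, a0, a1, b0, c0, c1);
}

// explode: expand each register array into individual arguments via index sequences,
// mirroring how detail::explode forwards rD, rA, rB, rC (and rE/rSFA/rSFB) to MMAOp::fma.
template <class Fn, class TD, std::size_t... Id,
          class TA, std::size_t... Ia,
          class TB, std::size_t... Ib,
          class TC, std::size_t... Ic>
void explode(Fn fn, TD const& d, std::index_sequence<Id...>,
                    TA const& a, std::index_sequence<Ia...>,
                    TB const& b, std::index_sequence<Ib...>,
                    TC const& c, std::index_sequence<Ic...>) {
  fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...);
}

int main() {
  float D[2] = {0, 1}, A[2] = {2, 3}, B[1] = {4}, C[2] = {5, 6};
  explode(fake_fma, D, std::make_index_sequence<2>{},
                    A, std::make_index_sequence<2>{},
                    B, std::make_index_sequence<1>{},
                    C, std::make_index_sequence<2>{});
}
```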
+template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const&, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + using SFARegisters = typename MMAOp::SFARegisters; + using SFBRegisters = typename MMAOp::SFBRegisters; + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using RegTypeSFA = typename remove_extent::type; + using RegTypeSFB = typename remove_extent::type; + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + constexpr int RegNumSFA = extent::value; + constexpr int RegNumSFB = extent::value; + + auto [tA, tSFA, tE] = unzip_tensor(A); + auto [tB, tSFB ] = unzip_tensor(B); + Tensor rA = recast(tA); + Tensor rE = recast(tE); + Tensor rB = recast(tB); + Tensor rD = recast(D); + Tensor rC = recast(C); + Tensor rSFA = recast(tSFA); + Tensor rSFB = recast(tSFB); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + CUTE_STATIC_ASSERT_V(size(filter_zeros(rSFA)) == Int{}); + CUTE_STATIC_ASSERT_V(size(filter_zeros(rSFB)) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}, + rSFA, make_int_sequence{}, + rSFB, make_int_sequence{}); +} + +} // end namespace SM120::BLOCKSCALED::SPARSE + + +namespace SM120::SPARSE +{ + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const&, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; 
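+  // A arrives as a coupled tensor: unzip it into the value fragment (tA) and the sparsity-metadata
+  // fragment (tE), recast every fragment into the MMA op's register types, and then explode the
+  // register arrays into a single fma call below.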
+ + auto [tA, tE] = unzip_tensor(A); + Tensor rA = recast(tA); + Tensor rE = recast(tE); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}); +} + +} // end namespace SM120::SPARSE + +// sparse F8F6F4 without block-scaling +template +struct MMA_Traits> +{ + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using FrgTypeA = sparse_elem<2, uint8_t>; + using FrgTypeE = sparse_elem<8, uint8_t>; + + using ValTypeC = c_type; + using ValTypeD = c_type; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,K64) + using ALayout = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_16,_8,_512>>>; + // (T32,V16) -> (N8,K64) + using BLayout = Layout,Shape <_4, _4>>, + Stride,Stride<_8,_128>>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM120_16x8_Row; + + // (T32, V32) -> (M16, K64) + using ELayout = Layout, _32>, + Stride,_16>>; +}; + +// sparse MXF8F6F4 with block-scaling. +template +struct MMA_Traits> + : MMA_Traits> +{ + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using FrgTypeA = sparse_elem<2, uint8_t>; + using FrgTypeE = sparse_elem<8, uint8_t>; + + using ValTypeD = c_type; + using ValTypeC = c_type; + + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + using UnderlyingSFTraits = MMA_Traits>; + using SFALayout = typename UnderlyingSFTraits::SFALayout; + using SFBLayout = typename UnderlyingSFTraits::SFBLayout; +}; + +template +struct MMA_Traits> +{ + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = uint4_t; + using FrgTypeA = sparse_elem<4, uint8_t>; + using FrgTypeE = sparse_elem<16, uint8_t>; + + using ValTypeC = c_type; + using ValTypeD = c_type; + + using ValTypeSF = sf_type; + + constexpr static int SFVecSize = VS; + + using Shape_MNK = Shape<_16, _8, _128>; + using ThrID = Layout<_32>; + // (T32,V64) -> (M16,K128) + using ALayout = Layout,Shape <_16,_2, _2>>, + Stride,Stride<_16,_8,_1024>>>; + // (T32,V32) -> (N8,K128) + using BLayout = Layout,Shape <_8, _4>>, + Stride,Stride<_8,_256>>>; + // (T32,V128) -> (M16,K128) + using SFALayout = Layout,_128>, + Stride, _16>>; + // (T32,V128) -> (N8,K128) + using SFBLayout = Layout,_128>, + Stride, _8>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM120_16x8_Row; + // (T32, V64) -> (M16, K128) + using ELayout = Layout, Shape< _64>>, + Stride,Stride<_16>>>; +}; + +namespace SM120::SPARSE { + +// For SM120 MMA F8F6F4 input fp4, the operand A/B are load from ld.matrix. +// ld.matrix b4x16_p64 places FP4 data at the first four bits in each +// eight-bit container, whereas MMA F8F6F4 expects the four-bit data to be in +// the middle of the eight-bit container. Thus, e2m1 operands being fed +// to MMA F8F6F4 must be shifted left by two bits. +// 0b0000ABCD --> 0b00ABCD00 +// NOTE: Same transformation is NOT needed for FP6 and FP8. 
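To make the repositioning described above concrete, a small host-side check (plain C++, illustration only, not kernel code): four FP4 containers packed in one 32-bit register are shifted left by two, and because the top two bits of every container start out zero, no bits spill into the neighboring byte.

```cpp
#include <cstdint>

// Four 8-bit containers, each holding an FP4 value in its low nibble (0b0000ABCD).
constexpr std::uint32_t packed  = 0x0D'0C'0B'0Au;
// After << 2 each container holds 0b00ABCD00, i.e. the nibble sits mid-byte as MMA F8F6F4 expects.
constexpr std::uint32_t shifted = packed << 2;

static_assert(shifted == 0x34'30'2C'28u, "each nibble moved two bits up within its own byte");
```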
+template +CUTLASS_DEVICE void +fp4_shift_A(SM120_SPARSE_16x8x64_TN const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_SPARSE_16x8x64_TN const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} // end namespace SM120::SPARSE + +namespace SM120::BLOCKSCALED::SPARSE { + +// Template function with scale factor needs to enmuerate types one by one, as template +// arguments contatins two variadic lists, which cannot be deduced in one shot. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120_SPARSE_16x8x64_TN_VS const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_SPARSE_16x8x64_TN_VS const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} // end namespace SM120::BLOCKSCALED::SPARSE + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index e3438f36..e688a7e6 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -239,11 +239,10 @@ make_gmma_desc(Tensor const& tensor) "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits for dense or (128|512)/sizeof_bits for sparse."); // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) - Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,_1>,Layout,_1>>{}); - // Check ranks of canonical - CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); - CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical GMMA_MN Layout: Expected profile failure."); // Check canonical mode strides constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == LayoutType::INTERLEAVE ? 
stride<0,0>(canonical_layout) : 1; @@ -274,11 +273,10 @@ make_gmma_desc(Tensor const& tensor) "Not a canonical GMMA_K Layout: Expected K-size 2 for dense or 4 for sparse (in units of uint128_t)."); // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) - Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,Layout<_2,_1>>{}); - // Check ranks of canonical - CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); - CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical GMMA_K Layout: Expected profile failure."); // Check canonical mode strides constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); constexpr uint32_t expected_stride_00 = W; diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 557e1103..f9d70004 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -34,6 +34,7 @@ #include // cute::array #include // cute::is_tuple #include // cute::Int +#include // cute::seq #include // cute::transform /** IntTuple is an integer or a tuple of IntTuples. @@ -349,7 +350,6 @@ ceil_div(IntTupleA const& a, IntTupleB const& b) // // round_up // Round @a a up to the nearest multiple of @a b. -// For negative numbers, rounds away from zero. // template @@ -378,7 +378,7 @@ round_up(IntTupleA const& a, IntTupleB const& b) * Return shape_div(a, product(b)) * Case Int Int: * Enforce the divisibility condition a % b == 0 || b % a == 0 when possible - * Return a / b with rounding away from 0 (that is, 1 or -1 when a < b) + * Return ceil_div(a, b) */ template CUTE_HOST_DEVICE constexpr @@ -399,32 +399,19 @@ shape_div(IntTupleA const& a, IntTupleB const& b) } else if constexpr (is_tuple::value) { // int tuple return shape_div(a, product(b)); - } else - if constexpr (is_static::value && is_static::value) { - static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure"); - return C{}; - } else { // int int - //assert(a % b == 0 || b % a == 0); // Waive dynamic assertion - return a / b != 0 ? 
a / b : signum(a) * signum(b); // Division with rounding away from zero - } - - CUTE_GCC_UNREACHABLE; -} - -/** Minimum for Shapes - */ -template -CUTE_HOST_DEVICE constexpr -auto -shape_min(IntTupleA const& a, IntTupleB const& b) -{ - if constexpr (is_tuple::value || is_tuple::value) { - static_assert(dependent_false, "Not implemented."); - } else - if constexpr (is_constant<1, IntTupleA>::value || is_constant<1, IntTupleB>::value) { - return Int<1>{}; // _1 is less than all other shapes, preserve static } else { - return cute::min(a, b); + // Strong divisibility condition + //static_assert((IntTupleA::value % IntTupleB::value == 0) or (IntTupleB::value % IntTupleA::value == 0), "Divisibility Condition"); + + // Weak divisibility condition + if constexpr (is_static::value and is_static::value) { + static_assert(((IntTupleA::value % IntTupleB::value) == 0) or ((IntTupleB::value % IntTupleA::value) == 0), "Divisibility Condition"); + } else { + // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure + //assert((((a % b) == 0) or ((a % b) == 0)) && "Divisibility Condition"); + } + + return (a + b - Int<1>{}) / b; } CUTE_GCC_UNREACHABLE; @@ -572,6 +559,72 @@ filter_zeros(Tuple const& t) return filter_zeros(t, t); } +// +// Static sorting utilities in detail:: +// + +namespace detail { + +// Some compilers fail to constexpr evaluate quick_sort +// template +// constexpr cute::array quick_sort(cute::array a, int lo = 0, int hi = N-1) { +// if (hi <= lo) return; +// int p = lo; +// for (int i = lo; i < hi; ++i) { +// if (a[i] < a[hi]) { +// T tmp = a[p]; a[p] = a[i]; a[i] = tmp; +// ++p; +// } +// } +// T tmp = a[p]; a[p] = a[hi]; a[hi] = tmp; +// a = quick_sort(a, lo, p-1); +// a = quick_sort(a, p+1, hi); +// return a; +// } + +template +constexpr cute::array exchange_sort(cute::array a) { + for (size_t i = 0; i < N; ++i) { + for (size_t j = i+1; j < N; ++j) { + if (a[j] < a[i]) { + T tmp = a[j]; a[j] = a[i]; a[i] = tmp; + } + } + } + return a; +} + +template >> +struct Sort : Sort, to_seq_t> {}; + +template +struct Sort, seq> { + static_assert(sizeof...(Vs) == sizeof...(Is)); + static constexpr cute::array orig_array = {Vs...}; + static constexpr cute::array sort_array = exchange_sort(orig_array); + using type = seq; +}; + +struct kvpair { + int key, val; + constexpr bool operator<(kvpair const& o) const { return key < o.key; }; +}; + +template >> +struct SortByKey : SortByKey, to_seq_t, to_seq_t> {}; + +template +struct SortByKey, seq, seq> { + static_assert(sizeof...(Ks) == sizeof...(Vs)); + static_assert(sizeof...(Ks) == sizeof...(Is)); + static constexpr cute::array orig_array = {kvpair{Ks,Vs}...}; + static constexpr cute::array sort_array = exchange_sort(orig_array); + using key_type = seq; + using val_type = seq; +}; + +} // end namespace detail + // // Converters and constructors with arrays and params // diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index adf460bb..4ee901ad 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -627,27 +627,37 @@ depth(Layout const& layout) return depth(shape(layout)); } +// Return the coprofile of a mode as a tuple of _0s +// @post congruent(coprofile(@a layout), @a layout(i)) for any i +// @return T Tuple that is congruent with the codomain of @a a. 
+template +CUTE_HOST_DEVICE constexpr +auto +coprofile(Layout const& layout) +{ + return repeat_like(as_arithmetic_tuple(sum(stride(layout))), Int<0>{}); +} + // Return the codomain shape of a mode -// @post size(coshape(@a a)) == cosize(@a a) +// @post size(coshape(@a layout)) == cosize(@a layout) // @return C Coordinate with smallest elements such that -// @a elem_less(sub_layout(c), C) for all c < size(@a sub_layout) -// where sub_layout = get(layout). +// elem_less(@a sub_layout(c), C) for all c < size(@a sub_layout) +// where @a sub_layout = get(layout). template CUTE_HOST_DEVICE constexpr auto coshape(Layout const& layout) { - // Protect against negative strides - auto abs_sub_layout = make_layout(shape(layout), - transform_leaf(stride(layout), abs_fn{})); - auto co_coord = as_arithmetic_tuple(abs_sub_layout(size(abs_sub_layout) - Int<1>{})); - return co_coord + repeat_like(co_coord, Int<1>{}); + auto m1_shapes = transform_leaf( shape(layout), [](auto s) { return s - Int<1>{}; }); + auto abs_strides = transform_leaf(stride(layout), abs_fn{}); + auto co_coord = as_arithmetic_tuple(inner_product(m1_shapes, abs_strides)); + return transform_leaf(co_coord, [](auto c) { return c + Int<1>{}; }); } // Return the codomain size of a mode // @return M smallest integer such that -// @a sub_layout(c) < M for all c < size(@a sub_layout) -// where sub_layout = get(layout). +// size(@a sub_layout(c)) < M for all c < size(@a sub_layout) +// where @a sub_layout = get(layout). template CUTE_HOST_DEVICE constexpr auto @@ -1019,61 +1029,93 @@ auto composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, RShape const& rhs_shape, RStride const& rhs_stride) { - if constexpr (is_tuple::value) { - // Apply the right-distributivity of Layout composition + if constexpr (is_tuple::value) { // Right-distributivity of Layout composition for RHS tuple return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition_impl(lhs_shape, lhs_stride, s, d); }); } else - if constexpr (is_scaled_basis::value) { - // Special case for a ScaledBasis stride + if constexpr (is_scaled_basis::value) { // Special case for a RHS ScaledBasis stride return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride), rhs_shape, basis_value(rhs_stride)); } else - if constexpr (is_constant<0, RStride>::value) { - // Special case shortcut for any static stride-0 + if constexpr (is_constant<0, RStride>::value) { // Special case shortcut for any RHS static stride-0 return Layout{rhs_shape, rhs_stride}; } else - if constexpr (is_integral::value) { - // Special case shortcut for any integral LShape + if constexpr (is_integral::value) { // Special case shortcut for any LHS integral shape return Layout{rhs_shape, rhs_stride * lhs_stride}; - } else - if constexpr (is_constant<1, RStride>::value) { - // Special case shortcut for any static stride-1 - constexpr int R = rank_v; - auto result_shape_0 = take<0,R-1>(lhs_shape); + } else { // General case: LHS tuple, RHS integral + constexpr int R = tuple_size::value; - // Mod out the rhs_shape from the lhs_shape - auto [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), - [] (auto const& init, auto const& si) { - return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); - }); + auto [result_shape, result_stride, rest_shape, rest_stride] = + cute::fold(make_seq{}, // t = [0,1,2,...,R-1) + cute::make_tuple(cute::tuple<>{}, // 
@@ -1019,61 +1029,93 @@ auto
 composition_impl(LShape const& lhs_shape, LStride const& lhs_stride,
                  RShape const& rhs_shape, RStride const& rhs_stride)
 {
-  if constexpr (is_tuple<RShape>::value) {
-    // Apply the right-distributivity of Layout composition
+  if constexpr (is_tuple<RShape>::value) {                 // Right-distributivity of Layout composition for RHS tuple
     return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) {
                               return composition_impl(lhs_shape, lhs_stride, s, d);
                             });
   } else
-  if constexpr (is_scaled_basis<RStride>::value) {
-    // Special case for a ScaledBasis stride
+  if constexpr (is_scaled_basis<RStride>::value) {         // Special case for a RHS ScaledBasis stride
     return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride),
                             rhs_shape, basis_value(rhs_stride));
   } else
-  if constexpr (is_constant<0, RStride>::value) {
-    // Special case shortcut for any static stride-0
+  if constexpr (is_constant<0, RStride>::value) {          // Special case shortcut for any RHS static stride-0
     return Layout{rhs_shape, rhs_stride};
   } else
-  if constexpr (is_integral<LShape>::value) {
-    // Special case shortcut for any integral LShape
+  if constexpr (is_integral<LShape>::value) {              // Special case shortcut for any LHS integral shape
     return Layout{rhs_shape, rhs_stride * lhs_stride};
-  } else
-  if constexpr (is_constant<1, RStride>::value) {
-    // Special case shortcut for any static stride-1
-    constexpr int R = rank_v<LShape>;
-    auto result_shape_0  = take<0,R-1>(lhs_shape);
+  } else {                                                 // General case: LHS tuple, RHS integral
+    constexpr int R = tuple_size<LShape>::value;
 
-    // Mod out the rhs_shape from the lhs_shape
-    auto [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape),
-      [] (auto const& init, auto const& si) {
-        return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si)));
-      });
+    auto [result_shape, result_stride, rest_shape, rest_stride] =
+      cute::fold(make_seq<R-1>{},                          // t = [0,1,2,...,R-1)
+                 cute::make_tuple(cute::tuple<>{},         // v = (result_shape,
+                                  cute::tuple<>{},         //      result_stride,
+                                  rhs_shape,               //      rest_shape:Integral,
+                                  rhs_stride),             //      rest_stride:Integral)
+                 [&](auto const& init, auto curr_i) {      // f(v,t) -> v'
+        // Can ICE on some compilers
+        //auto [result_shape, result_stride, rest_shape, rest_stride] = init;
+        //auto [curr_shape, curr_stride] = curr;
+        // Unpack inputs
+        auto result_shape  = get<0>(init);
+        auto result_stride = get<1>(init);
+        auto rest_shape    = get<2>(init);
+        auto rest_stride   = get<3>(init);
 
-    // Jump into coalesce and append (rest_shape, get<R-1>(lhs_stride))
-    return detail::bw_coalesce<R-2>(result_shape_1, lhs_stride, rest_shape, get<R-1>(lhs_stride));
-  } else {
-    // General case: integral RShape and RStride, tuple LShape and LStride
-    constexpr int R = rank_v<LShape>;
-    auto result_shape_0  = take<0,R-1>(lhs_shape);
-    auto result_stride_0 = take<0,R-1>(lhs_stride);
+        auto curr_shape  = get<curr_i>(lhs_shape);
+        auto curr_stride = get<curr_i>(lhs_stride);
 
-    // Divide out the rhs_stride from the lhs_shape
-    auto [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride),
-      [] (auto const& init, auto const& di) {
-        return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di));
-      });
+        // Strong divisibility condition -- requires composition to be statically verifiable.
+        //CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition");
 
-    // Apply any lhs_shape changes to the stride
-    auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1));
+        // Weak divisibility condition -- verify the divisibility condition whenever possible
+        if constexpr (is_static<decltype(rest_stride)>::value and is_static<decltype(curr_shape)>::value) {
+          CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition");
+        } else {
+          // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure
+          //assert((((rest_stride % curr_shape) == 0) or (rest_stride < curr_shape)) && "Stride Divisibility Condition");
+        }
 
-    // Mod out the rhs_shape from the lhs_shape
-    auto [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape),
-      [] (auto const& init, auto const& si) {
-        return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si)));
-      });
+        // next_shape: ceil(exclusive_prefix_product(lhs_shape) / rhs_stride)
+        [[maybe_unused]] auto next_shape  = cute::ceil_div(curr_shape, abs(rest_stride));
+        // next_stride: ceil(rhs_stride / exclusive_prefix_product(lhs_shape))
+        [[maybe_unused]] auto next_stride = cute::ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride);
 
-    // Jump into coalesce and append (rest_shape, rest_stride * get<R-1>(lhs_stride))
-    return detail::bw_coalesce<R-2>(result_shape_2, result_stride_1, rest_shape, rest_stride * get<R-1>(lhs_stride));
+        if constexpr (is_constant<1, decltype(next_shape)>::value or is_constant<1, decltype(rest_shape)>::value) {
+          return cute::make_tuple(result_shape,
+                                  result_stride,
+                                  rest_shape,
+                                  next_stride);
+        } else {
+          auto new_shape = cute::min(next_shape, rest_shape);
+
+          // Strong divisibility condition
+          //CUTE_STATIC_ASSERT_V(((rest_shape % new_shape) == Int<0>{}), "Shape Divisibility Condition");
+
+          // Weak divisibility condition
+          if constexpr (is_static<decltype(rest_shape)>::value and is_static<decltype(new_shape)>::value) {
+            CUTE_STATIC_ASSERT_V(((rest_shape % new_shape) == Int<0>{}), "Shape Divisibility Condition");
+          } else {
+            // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure
+            //assert(((rest_shape % new_shape) == 0) && "Shape Divisibility Condition");
+          }
+
+          return cute::make_tuple(append(result_shape, new_shape),
+                                  append(result_stride, rest_stride * curr_stride),
+                                  rest_shape / new_shape,
+                                  next_stride);
+        }
+      });
+
+    if constexpr (tuple_size<decltype(result_shape)>::value == 0) {
+      return Layout{rest_shape, rest_stride * get<R-1>(lhs_stride)};
+    } else
+    if constexpr (is_constant<1, decltype(rest_shape)>::value) {
+      return Layout{unwrap(result_shape), unwrap(result_stride)};
+    } else {
+      return Layout{append(result_shape, rest_shape),
+                    append(result_stride, rest_stride * get<R-1>(lhs_stride))};
+    }
   }
 
   CUTE_GCC_UNREACHABLE;
@@ -1088,8 +1130,7 @@ auto
 composition(Layout<LShape,LStride> const& lhs,
             Layout<RShape,RStride> const& rhs)
 {
-  auto coprofile = repeat_like(decltype(coshape(rhs)){}, Int<0>{});
-  auto flat_lhs = detail::coalesce_x(lhs, coprofile);
+  auto flat_lhs = detail::coalesce_x(lhs, coprofile(rhs));
   return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs.shape(), rhs.stride());
 }
 
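Illustrative sketch (not part of the patch): a composition whose stride and shape divisibility conditions are statically verifiable. The layouts and expected values are assumptions chosen for the example and only rely on the documented composition() behavior.

// Sketch only. (4,4):(1,4) composed with (8):(2) selects every other element.
#include <cute/layout.hpp>
namespace composition_example {
using namespace cute;
using A = Layout<Shape<_4,_4>, Stride<_1,_4>>;   // 16 elements, column-major
using B = Layout<Shape<_8>,    Stride<_2>>;      // stride-2 view of 8 elements
inline constexpr auto AoB = composition(A{}, B{});
static_assert(size(AoB) == Int<8>{});
static_assert(AoB(Int<3>{}) == A{}(B{}(Int<3>{})));   // (A o B)(i) == A(B(i))
} // namespace composition_example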
@@ -1203,37 +1244,6 @@ complement(Layout<Shape,Stride> const& layout)
 // Right-Inverse and Left-Inverse
 //
 
-namespace detail {
-
-template <int NextStride, class Shape, class Stride, int... Is>
-CUTE_HOST_DEVICE constexpr
-auto
-inverse_seq(Shape const& shape, Stride const& stride, seq<Is...>)
-{
-  auto next_I = cute::find_if(stride, [](auto a) { return is_constant<NextStride, decltype(a)>{}; });
-
-  if constexpr (next_I == decltype(rank(stride))::value) {
-    // If not found, return current seq
-    return seq<Is...>{};
-  } else {
-    // auto next_stride = get<next_I>(shape) * get<next_I>(stride);
-    // NOTE: Needed for g++-7
-    using next_stride = decltype(get<next_I>(shape) * get<next_I>(stride));
-
-    if constexpr (is_static<next_stride>::value && !is_constant<NextStride, next_stride>::value) {
-      // If next_stride is static and unique, then continue
-      return inverse_seq<next_stride::value>(shape, stride, seq<Is..., next_I>{});
-    } else {
-      // Else return current seq + next_I
-      return seq<Is..., next_I>{};
-    }
-  }
-
-  CUTE_GCC_UNREACHABLE;
-}
-
-} // end namespace detail
-
 //
 // Build the right-inverse of a layout
 // @pre is_static
@@ -1248,22 +1258,40 @@ CUTE_HOST_DEVICE constexpr
 auto
 right_inverse(Layout<Shape,Stride> const& layout)
 {
-  auto flat_layout = coalesce(layout);
-  auto astride = transform_leaf(flat_layout.stride(), abs_fn{});
+  // Flatten and filter shape-1
+  auto clayout = coalesce(layout);
+  auto lstride = wrap(clayout.stride());
+  auto lshape  = wrap(clayout.shape());
 
-  // Find Int<1>{}, the starting stride, and follow the strides to gen inverse_seq
-  [[maybe_unused]] auto iseq = detail::inverse_seq<1>(flat_layout.shape(), astride, seq<>{});
+  // Prefix product of the shape
+  auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); });
 
-  if constexpr (iseq.size() == 0) {
-    return Layout<_1,_0>{};     // Empty case, nothing found
-  } else {
-    // Generate the corresponding new strides and construct
-    auto rstride = compact_major<LayoutLeft>(flat_layout.shape());
-    return make_layout(unwrap(transform(iseq, [&](auto i) { return shape<i>(flat_layout); })),
-                       unwrap(transform(iseq, [&](auto i) { return signum(stride<i>(flat_layout)) * get<i>(rstride); })));
-  }
+  // Filter out any dynamic strides
+  [[maybe_unused]] auto filtered_seq = filter_tuple(make_seq<rank_v<decltype(lshape)>>{}, lstride, [](auto i, auto d) {
+      return conditional_return<is_static<decltype(d)>::value>(cute::tuple{i}, cute::tuple<>{}); });
+  [[maybe_unused]] auto filtered_stride = transform(filtered_seq, [&](auto i) { return get<i>(lstride); });
 
-  CUTE_GCC_UNREACHABLE;
+  // Sort by strides
+  using Sorted = detail::SortByKey<decltype(filtered_stride), decltype(filtered_seq)>;
+  auto sorted_seq = typename Sorted::val_type{};
+  //auto sorted_stride = typename Sorted::key_type{};
+
+  auto [result_shape, result_stride, curr] = cute::fold(sorted_seq, tuple<tuple<_1>,tuple<_0>,_1>{},
+    [&](auto const& init, auto i) {
+      [[maybe_unused]] auto ishape      = get<i>(lshape);
+      [[maybe_unused]] auto istride     = get<i>(lstride);
+      [[maybe_unused]] auto curr_stride = get<2>(init);
+
+      if constexpr (is_constant<true, decltype(istride == curr_stride)>::value) {
+        return make_tuple(append(get<0>(init), ishape),                // result_shape
+                          append(get<1>(init), get<i>(preprod_shape)), // result_stride
+                          ishape * istride);
+      } else {
+        return init;
+      }
+    });
+
+  return coalesce(make_layout(result_shape, result_stride));
 }
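Illustrative sketch (not part of the patch): the right_inverse contract exercised by the rewritten implementation above. The example layout is an assumption; the identity only relies on the documented behavior layout(right_inverse(layout)(i)) == i.

// Sketch only. For the row-major layout (2,4):(4,1), the right inverse
// coalesces to (4,2):(2,1) and undoes the index mapping on the codomain.
#include <cute/layout.hpp>
namespace right_inverse_example {
using namespace cute;
using L = Layout<Shape<_2,_4>, Stride<_4,_1>>;
inline constexpr auto Rinv = right_inverse(L{});
static_assert(L{}(Rinv(Int<5>{})) == Int<5>{});   // identity on the codomain
} // namespace right_inverse_example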
 
 CUTE_HOST_DEVICE constexpr
@@ -1274,13 +1302,12 @@ right_inverse(Underscore const& _)
 }
 
 //
-// Build the left-inverse of a layout
+// Build the quasi-inverse of a layout (left-inverse when layout is injective)
 // @pre is_static
-// @pre @a layout is an injective function
 // @result A layout @a result such that
-//         @a result(@a layout(i)) == i for all i < size(@a layout)
+//         @a layout(@a result(@a layout(i))) == @a layout(i) for all i < size(@a layout)
 // @result A layout @a result such that
-//         composition(@a result, @a layout) is identical to make_layout(shape(layout))
+//         composition(@a layout, composition(@a result, @a layout)) is identical to @a layout
 //
 
 template <class Shape, class Stride>
@@ -1288,7 +1315,39 @@ CUTE_HOST_DEVICE constexpr
 auto
 left_inverse(Layout<Shape,Stride> const& layout)
 {
-  return right_inverse(make_layout(layout, complement(layout)));
+  // Flatten and filter shape-1
+  auto clayout = coalesce(layout);
+  auto lstride = wrap(clayout.stride());
+  auto lshape  = wrap(clayout.shape());
+
+  // Prefix product of the shape
+  auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); });
+
+  // Sort by strides
+  static_assert(is_static<decltype(lstride)>::value, "Left inverse requires static strides.");
+  using Sorted = detail::SortByKey<decltype(lstride), tuple_seq<decltype(lstride)>>;
+  auto sorted_seq = typename Sorted::val_type{};
+  //auto sorted_stride = typename Sorted::key_type{};
+
+  auto [result_shape, result_stride] = cute::fold(sorted_seq, tuple<tuple<>,tuple<_0>>{},
+    [&](auto const& init, auto i) {
+      [[maybe_unused]] auto istride = get<i>(lstride);
+
+      if constexpr (is_constant<0, decltype(istride)>::value) {
+        return init;
+      } else {
+        auto result_shape  = get<0>(init);
+        auto result_stride = get<1>(init);
+
+        CUTE_STATIC_ASSERT_V((istride % size(result_shape)) == Int<0>{}, "Left inverse divisibility condition");
+
+        return make_tuple(append(result_shape, istride / size(result_shape)),
+                          append(result_stride, get<i>(preprod_shape)));
+      }
+    });
+
+  return coalesce(make_layout(append(result_shape, get<back(sorted_seq)>(lshape)),
+                              result_stride));
 }
 
 CUTE_HOST_DEVICE constexpr
@@ -1506,7 +1565,7 @@ auto
 logical_divide(Layout<LShape,LStride> const& layout,
               Layout<TShape,TStride> const& tiler)
 {
-  return composition(layout, make_layout(tiler, complement(tiler, shape(layout))));
+  return composition(layout, make_layout(tiler, complement(tiler, shape(coalesce(layout)))));
 }
 
 template <class LShape, class LStride, class Tiler>
@@ -1760,10 +1819,11 @@ upcast(Shape const& shape, Stride const& stride)
   } else if constexpr (is_constant<0, Stride>::value) {    // static-0 stride
     return Layout{shape,stride};
   } else if constexpr (is_static<Stride>::value) {         // static stride
-    return make_layout(shape_div(shape, shape_div(Int<N>{}, abs(stride))),
-                       shape_div(stride, Int<N>{}));
+    static_assert(Stride::value % N == 0 or N % Stride::value == 0, "Divisibility condition");
+    return make_layout(ceil_div(shape, ceil_div(Int<N>{}, abs(stride))),
+                       signum(stride) * ceil_div(abs(stride), Int<N>{}));
   } else {                                                 // dynamic stride
-    // assume dynamic strides are larger than N and divisible
+    // Assume dynamic strides are larger than N and divisible
     // assert(stride % N == 0);
     return make_layout(shape, safe_div(stride, Int<N>{}));
   }
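Illustrative sketch (not part of the patch): the relaxed quasi-inverse contract that the left_inverse documentation above now states, i.e. layout(result(layout(i))) == layout(i). The example layout is an assumption chosen so the image is a strided, non-surjective subset of the codomain.

// Sketch only. (4,2):(2,8) is injective with image {0,2,...,14}; the
// quasi-inverse recovers the original coordinate on that image.
#include <cute/layout.hpp>
namespace left_inverse_example {
using namespace cute;
using L = Layout<Shape<_4,_2>, Stride<_2,_8>>;
inline constexpr auto Linv = left_inverse(L{});
static_assert(L{}(Linv(L{}(Int<5>{}))) == L{}(Int<5>{}));
} // namespace left_inverse_example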
diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp
index fc26fbb3..6a967783 100644
--- a/include/cute/layout_composed.hpp
+++ b/include/cute/layout_composed.hpp
@@ -37,7 +37,7 @@
 /* This implements a ComposedLayout of the form
  *   LayoutA o Offset o LayoutB
  * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB.
- * For example, when the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB).
+ * For example, when the "divisibility condition" is violated in composition(LayoutA, LayoutB).
  *
  * This ComposedLayout provides similar functionality to Layout including tiling, partitioning,
  * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout.
diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp
index 3c2c23cc..33076378 100644
--- a/include/cute/numeric/arithmetic_tuple.hpp
+++ b/include/cute/numeric/arithmetic_tuple.hpp
@@ -370,12 +370,21 @@ safe_div(ScaledBasis<T,M> const& b, U const& u)
 template <class T, int M, class U>
 CUTE_HOST_DEVICE constexpr
 auto
-shape_div(ScaledBasis<T,M> const& b, U const& u)
+ceil_div(ScaledBasis<T,M> const& b, U const& u)
 {
-  auto t = shape_div(b.value(), u);
+  auto t = ceil_div(b.value(), u);
   return ScaledBasis<decltype(t),M>{t};
 }
 
+template <class T, int M>
+CUTE_HOST_DEVICE constexpr
+auto
+abs(ScaledBasis<T,M> const& e)
+{
+  auto t = abs(e.value());
+  return ScaledBasis<decltype(t),M>{t};
+}
+
 // Equality
 template <class T, int N, class U, int M>
 CUTE_HOST_DEVICE constexpr
@@ -399,14 +408,6 @@ operator==(T const&, ScaledBasis<U,N> const&) {
   return {};
 }
 
-// Abs
-template <class T, int M>
-CUTE_HOST_DEVICE constexpr
-auto
-abs(ScaledBasis<T,M> const& e) {
-  return ScaledBasis<decltype(abs(e.value())),M>{abs(e.value())};
-}
-
 // Multiplication
 template <class A, class T, int N>
 CUTE_HOST_DEVICE constexpr
diff --git a/include/cute/numeric/integer_sequence.hpp b/include/cute/numeric/integer_sequence.hpp
index 4118d9cb..799e1896 100644
--- a/include/cute/numeric/integer_sequence.hpp
+++ b/include/cute/numeric/integer_sequence.hpp
@@ -124,6 +124,31 @@ using tuple_seq = make_seq<tuple_size<remove_cvref_t<T>>::value>;
 
 template <class T>
 using tuple_rseq = make_rseq<tuple_size<remove_cvref_t<T>>::value>;
 
+//
+// Convert a parameter pack to an int sequence
+//
+
+template <class T>
+struct to_seq;
+
+template <>
+struct to_seq<tuple<>> {
+  using type = seq<>;
+};
+
+template <class... Ts>
+struct to_seq<tuple<Ts...>> {
+  using type = seq<Ts::value...>;
+};
+
+template