diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ca22eaf..00728725 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,11 +13,16 @@ - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). + - [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu). + - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu). +* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu). +* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture. * Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. * Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture. diff --git a/README.md b/README.md index 4593e94d..433c375c 100644 --- a/README.md +++ b/README.md @@ -50,11 +50,16 @@ architecture. - [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu). - [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu). - [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu). 
+ - [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu). + - [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu). +* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/77_blackwell_mla.cu). +* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture. * Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM. * Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures: - Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture. - Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture. - Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture. + - Support for [mixed-dtype grouped GEMM with groupwise scaling](./examples/69_hopper_mixed_dtype_grouped_gemm) for Hopper architecture. - Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture. - Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture. - Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture. 
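The groupwise-scaling support listed above keeps one scale value per chunk of `c` consecutive K elements, so the scale tensor needs ceil(k / c) entries along K; the hunks below compute this with `cutlass::ceil_div` in place of the manual round-up idiom. A minimal sketch of that arithmetic, using the `--k=8192 --c=512` sizes from the new TEST_SCALE_GROUP configuration and assuming `cutlass::ceil_div` (round-up integer division) is available from `cutlass/fast_math.h`:

```cpp
// Sketch only, not part of the patch: sizing the groupwise scale tensor along K.
// Assumption: cutlass::ceil_div (round-up integer division) from "cutlass/fast_math.h".
#include "cutlass/fast_math.h"
#include <cassert>

int main() {
  int const k = 8192;                            // GEMM K extent       (--k=8192)
  int const c = 512;                             // scale chunk size    (--c=512)
  int const scale_k = cutlass::ceil_div(k, c);   // scales needed per column of B
  assert(scale_k == (k + c - 1) / c);            // identical to the manual round-up idiom
  assert(scale_k == 16);
  assert(cutlass::ceil_div(k + 1, c) == 17);     // a trailing partial chunk still gets its own scale
  return 0;
}
```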
diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu index 6fdcc836..c9fbd756 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -402,7 +402,7 @@ struct Options : MixedDtypeOptions{ void initialize(Options const& options) { auto shape_B = cute::make_shape(options.n, options.k, options.l); - int const scale_k = (options.k + options.g - 1) / options.g; + int const scale_k = cutlass::ceil_div(options.k, options.g); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); // Reverse stride here due to swap and transpose @@ -429,7 +429,7 @@ void initialize(Options const& options) { block_zero.reset(scale_k * options.l * options.n); initialize_tensor(block_A, seed + 2022); - initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_B, seed + 2021); initialize_tensor(block_C, seed + 2020); initialize_scale(block_scale, options); initialize_zero(block_zero, options); diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu index cc540803..dcab4a7a 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -318,7 +318,7 @@ struct Options : MixedDtypeOptions { void initialize(Options const& options) { auto shape_B = cute::make_shape(options.n, options.k, options.l); - int const scale_k = (options.k + options.g - 1) / options.g; + int const scale_k = cutlass::ceil_div(options.k, options.g); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); // Reverse stride here due to swap and transpose @@ -347,7 +347,7 @@ void initialize(Options const& options) { block_zero.reset(scale_k * options.l * options.n); initialize_tensor(block_A, seed + 2022); - initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_B, seed + 2021); cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size()); initialize_tensor(block_C, seed + 2020); initialize_scale(block_scale, options); diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index aa114e74..15eb4692 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -288,7 +288,7 @@ cutlass::DeviceAllocation -bool initialize_quant_tensor( - cutlass::DeviceAllocation& block, - uint64_t seed = 2023) { - - float scope_min = float(cutlass::platform::numeric_limits::lowest()); - float scope_max = float(cutlass::platform::numeric_limits::max()); - - cutlass::reference::device::BlockFillRandomUniform( - block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); - - return true; -} - template bool initialize_scale( cutlass::DeviceAllocation& block, @@ -232,10 +218,8 @@ bool initialize_scale( float scope_max = 1.0f, scope_min = 1.0f; if (options.mode != MixedDtypeGemmMode::ConvertOnly) { float elt_max_f = float(cutlass::platform::numeric_limits::max()); - const float max_dequant_val = 4.f; - const float min_dequant_val = 0.5f; - scope_max = 
max_dequant_val / elt_max_f; - scope_min = min_dequant_val / elt_max_f; + scope_max = 2.f; + scope_min = 0.1f; } cutlass::reference::device::BlockFillRandomUniform( block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); diff --git a/examples/65_distributed_gemm/65_distributed_gemm.cu b/examples/65_distributed_gemm/65_distributed_gemm.cu index 90b6ff8b..6509609f 100644 --- a/examples/65_distributed_gemm/65_distributed_gemm.cu +++ b/examples/65_distributed_gemm/65_distributed_gemm.cu @@ -120,8 +120,7 @@ #include "helper.h" // Distributed GEMM helpers -#include "util/benchmark.h" -#include "util/device_copy.h" +#include "dist_gemm_helpers.h" using namespace cute; diff --git a/examples/65_distributed_gemm/util/device_copy.h b/examples/65_distributed_gemm/util/device_copy.h deleted file mode 100644 index 257800a0..00000000 --- a/examples/65_distributed_gemm/util/device_copy.h +++ /dev/null @@ -1,84 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -/*! \file - \brief generic device-to-device data movement kernel based for CuTe tensors. - - NOTE: this kernel assigns one element copy to every thread, and is by no means - an efficient way of copying tensors. It should only be used for convenience in - reference checks. 
- -*/ - -#pragma once - -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/cuda_host_adapter.hpp" - -namespace cutlass { - -template -void device_copy(TensorSource tensor_source, - TensorDestination tensor_destination, - cudaStream_t stream); - - -template -__global__ void device_copy_kernel(TensorSource const tensor_source, - TensorDestination tensor_destination) { - auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x; - using ElementSrc = typename TensorSource::value_type; - using ElementDst = typename TensorDestination::value_type; - NumericConverter converter; - if (linear_idx < size(tensor_source)) { - tensor_destination(linear_idx) = converter(tensor_source(linear_idx)); - } -} - -template -void device_copy(TensorSource tensor_source, - TensorDestination tensor_destination, - cudaStream_t stream) { - - assert(tensor_source.size() == tensor_destination.size()); - - auto numel = tensor_source.size(); - static constexpr int NumThreads = 128; - auto grid_size = cute::ceil_div(numel, NumThreads); - - dim3 grid(grid_size); - dim3 block(NumThreads); - device_copy_kernel<<>>(tensor_source, tensor_destination); -} - -} //namespace cutlass diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu index c1978c32..9b56697b 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu @@ -374,7 +374,7 @@ void allocate(Options const& options) { auto N = get<1>(problem); auto K = get<2>(problem); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); @@ -510,7 +510,7 @@ void initialize(Options &options) { beta_device.copy_from_host(ptr_beta_host.data()); initialize_tensor(block_A, seed + 2023); - initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_B, seed + 2022); initialize_tensor(block_C, seed + 2021); initialize_scale(block_scale, options); initialize_zero(block_zero, options); @@ -519,13 +519,13 @@ void initialize(Options &options) { for (int32_t i = 0; i < options.groups; ++i) { - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{}); auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{}); auto layout_B = make_layout(shape_B, stride_B_host.at(i)); auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i)); cudaStream_t stream = cudaStreamDefault; - cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream); + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.c, stream); } problem_sizes.reset(options.groups); @@ -619,7 +619,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro arguments = Args { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, - {ptr_B.get(), 
dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k}, + {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; @@ -676,6 +676,7 @@ bool verify(Options const& options) { for (int32_t i = 0; i < options.groups; ++i) { auto problem = options.problem_sizes_host.at(i); + // we don't swap and transpose in the verify so revert the problem shape. auto N = get<0>(problem); auto M = get<1>(problem); auto K = get<2>(problem); @@ -712,7 +713,7 @@ bool verify(Options const& options) { CUDA_CHECK(cudaDeviceSynchronize()); passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); - std::cout << "Group: " << i << " Status: " << passed << std::endl; + std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl; } } return passed; diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu index 07ff66b3..8407cdad 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu @@ -341,7 +341,7 @@ void allocate(Options const& options) { auto N = get<1>(problem); auto K = get<2>(problem); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); @@ -479,7 +479,7 @@ void initialize(Options& options) { beta_device.copy_from_host(ptr_beta_host.data()); initialize_tensor(block_A, seed + 2023); - initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_B, seed + 2022); cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size()); initialize_tensor(block_C, seed + 2021); initialize_scale(block_scale, options); @@ -565,7 +565,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro arguments = Args { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, - {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k}, + {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.c}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; @@ -617,6 +617,7 @@ bool verify(Options const& options) { for (int32_t i = 0; i < options.groups; ++i) { auto problem = options.problem_sizes_host.at(i); + // we don't swap and transpose in the verify so revert the problem shape. 
auto N = get<0>(problem); auto M = get<1>(problem); auto K = get<2>(problem); @@ -630,11 +631,11 @@ bool verify(Options const& options) { stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1)); stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1)); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i)); auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i)); cudaStream_t stream = cudaStreamDefault; - cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream); + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream); // // Compute reference output @@ -659,7 +660,7 @@ bool verify(Options const& options) { CUDA_CHECK(cudaDeviceSynchronize()); passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); - std::cout << "Group: " << i << " Status: " << passed << std::endl; + std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl; } } return passed; diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu index ffeb233e..41cccfbb 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu @@ -282,7 +282,7 @@ void allocate(Options const& options) { auto N = get<1>(problem); auto K = get<2>(problem); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); offset_A.push_back(total_elements_A); offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); @@ -418,7 +418,7 @@ void initialize(Options &options) { beta_device.copy_from_host(ptr_beta_host.data()); initialize_tensor(block_A, seed + 2023); - initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_B, seed + 2022); initialize_tensor(block_C, seed + 2021); initialize_scale(block_scale, options); initialize_zero(block_zero, options); @@ -485,7 +485,7 @@ typename Gemm::Arguments args_from_options(Options const& options, bool host_pro arguments = typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, - {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k}, + {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; @@ -542,6 +542,7 @@ bool verify(Options const& options) { for (int32_t i = 0; i < options.groups; ++i) { auto problem = options.problem_sizes_host.at(i); + // we don't swap and transpose in the verify so revert the problem shape. 
auto N = get<0>(problem); auto M = get<1>(problem); auto K = get<2>(problem); @@ -555,11 +556,11 @@ bool verify(Options const& options) { stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1)); stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1)); - const int scale_k = 1; + int const scale_k = cutlass::ceil_div(options.k, options.c); auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i)); auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i)); cudaStream_t stream = cudaStreamDefault; - cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream); + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream); // // Compute reference output @@ -584,7 +585,7 @@ bool verify(Options const& options) { CUDA_CHECK(cudaDeviceSynchronize()); passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); - std::cout << "Group: " << i << " Status: " << passed << std::endl; + std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl; } } return passed; diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt b/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt index 4c21cd48..f32c5d52 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt @@ -50,6 +50,7 @@ set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10) set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling +set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0) # Group-wise scaling cutlass_example_add_executable( 69_hopper_mixed_dtype_grouped_gemm @@ -69,6 +70,7 @@ cutlass_example_add_executable( TEST_RANDOM_PERF_LARGE_GROUP TEST_DIRECT_BATCHED TEST_SCALE_PERCOL + TEST_SCALE_GROUP ) cutlass_example_add_executable( @@ -89,6 +91,7 @@ cutlass_example_add_executable( TEST_RANDOM_PERF_LARGE_GROUP TEST_DIRECT_BATCHED TEST_SCALE_PERCOL + TEST_SCALE_GROUP ) cutlass_example_add_executable( @@ -109,4 +112,5 @@ cutlass_example_add_executable( TEST_RANDOM_PERF_LARGE_GROUP TEST_DIRECT_BATCHED TEST_SCALE_PERCOL + TEST_SCALE_GROUP ) diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md index f4d71ea3..10b57aa0 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md @@ -7,11 +7,11 @@ This example shows how to perform Grouped GEMMs on Hopper when A and B have diff - in the arguments, pass the group size, array of the problem sizes, and the array of strides for matrix A and B. - if scales and zero-points are included, also pass the array of their strides in the arguments. -Note that in Example 55, the argument `--g` is used to determine the block scale size. 
It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs. +Note that in Example 55, the argument `--g` is used to determine the group size of scaling. To avoid confusion with the `--groups` argument in this example, which defines the number of GEMMs, `--c` is used here to represent the group size for scaling. ## Upcoming features -Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed. +Currently, the Mixed-input Grouped GEMM only supports row-wise scaling, and group-wise scaling for identical problem shapes across all groups. Please contact us if zero-points or block-wise scaling are needed. ## Copyright diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp b/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp index db391cce..8568b467 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp @@ -58,6 +58,7 @@ public: void parse(int argc, char const **args) { cutlass::CommandLine cmd(argc, args); cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("benchmark", benchmark_path); cmd.get_cmd_line_argument("c", c); MixedDtypeOptions::parse(argc, args); @@ -71,6 +72,7 @@ public: << " --m= Sets the M extent of the GEMM for all groups\n" << " --n= Sets the N extent of the GEMM for all groups\n" << " --k= Sets the K extent of the GEMM for all groups\n" + << " --c= Sets the chunk size for scaling the quantized weights\n" << " --groups= Sets the number of individual GEMM problems\n" << " --mode= The mode to run the gemm\n" << " --alpha= Epilogue scalar alpha\n" @@ -183,11 +185,6 @@ void grouped_mixed_dtype_profiling( result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); - - std::cout << " Problem Sizes, Alpha, Beta\n"; - for (int32_t i = 0; i < options.groups; ++i) { - std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n'; - } std::cout << " Groups : " << options.groups << '\n' << " Avg runtime : " << result.avg_runtime_ms << " ms\n" << " GFLOPS : " << result.gflops << '\n'; diff --git a/examples/77_blackwell_fmha/77_blackwell_mla.cu b/examples/77_blackwell_fmha/77_blackwell_mla.cu new file mode 100644 index 00000000..baa70fce --- /dev/null +++ b/examples/77_blackwell_fmha/77_blackwell_mla.cu @@ -0,0 +1,832 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file A MLA (Multi-Head Latent Attention) inference kernel sample for the + NVIDIA Blackwell Architecture. +*/ + +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "reference/fmha_mla_reference.hpp" +#include "reference/reference_abs_error.hpp" + +#include "device/sm100_mla.hpp" +#include "kernel/sm100_mla_tile_scheduler.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; +using namespace cutlass::fmha::kernel; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class InitStyle { + kOne, kLinearStride128, kLinearStride1, kRandom, kNone +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help = false; + bool error = false; + + int b = 1; + int k = 256; + int split_kv = -1; // number of split along k dim. 
+ bool is_var_split_kv = false; + int max_split_kv = 16; + int page = -1; + float spread = 0.2f; + int iterations = 3; + bool verify = false; + bool verbose = false; + + int sm_count = 0; + + std::string kernel_filter; + + InitStyle init_style_q = InitStyle::kRandom; + InitStyle init_style_c = InitStyle::kRandom; + + static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) { + std::string s; + cmd.get_cmd_line_argument(name, s, s); + if (s.empty()) { + dst = src; + } + else { + if (s == "r") { + dst = InitStyle::kRandom; + } + else if (s == "1") { + dst = InitStyle::kOne; + } + else if (s == "d") { + dst = InitStyle::kLinearStride1; + } + else if (s == "s") { + dst = InitStyle::kLinearStride128; + } + else if (s == "n") { + dst = InitStyle::kNone; + } + else { + std::cout << "Error: " << s << " is not a valid input type.\n"; + std::exit(-1); + } + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + Options defaults; + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("k", k, -1); + if (k == -1) k = defaults.k; + + cmd.get_cmd_line_argument("b", b, -1); + if (b == -1) b = 16384 / k; + if (b == 0) b = 1; + + cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv); + cmd.get_cmd_line_argument("page", page, defaults.page); + cmd.get_cmd_line_argument("spread", spread, defaults.spread); + cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false); + if (page == -1) { + is_var_split_kv = false; + } + cmd.get_cmd_line_argument("max_split_kv", max_split_kv, defaults.max_split_kv); + if (is_var_split_kv == true) { + split_kv = max_split_kv; + } + cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + verify = cmd.check_cmd_line_flag("verify"); + verbose = cmd.check_cmd_line_flag("verbose"); + cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); + + get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); + get_init_style_argument(cmd, "init-style", init_style_c, defaults.init_style_c); + get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q); + get_init_style_argument(cmd, "init-style-c", init_style_c, init_style_c); + + cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "77_blackwell_mla\n\n" + << " This example showcases the use of CUTLASS for fused multi-head latent\n" + << " attention kernels targeting NVIDIA's Blackwell architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --b= Sets the B extent\n" + << " --k= Sets the K extent\n" + << " --page= Enables paging and sets the page size\n" + << " --iterations= Benchmarking iterations\n" + << " --spread= Relative spread away from K for paging\n" + << " --split_kv= Split KV factor\n" + << " --verify Verify results\n" + << " --verbose Print smem and execution time per kernel\n" + << " --sm-count Sets SM count rather than querying it\n" + << " --kernel-filter= Sets regexp to match kernel against\n" + << "\n"; + + return out; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_block( + DeviceAllocation& block, + uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) { + + switch (init_style) { + case InitStyle::kOne: { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 1, (Element) 1); + break; + } + case InitStyle::kRandom: { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) -1, (Element) 1); + break; + } + case InitStyle::kLinearStride1: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 128; i ++) { + for (int j = 0; j < 128; j++) { + data[j + 128*i] = static_cast((double) (j % 4)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kLinearStride128: { + std::vector data(block.size()); + for (size_t i = 0; i < block.size() / 64; i ++) { + for (int j = 0; j < 64; j++) { + data[j + 64*i] = static_cast((double) (i % 9)); + } + } + block.copy_from_host(data.data(), data.size()); + break; + } + case InitStyle::kNone: { + break; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleResult { + bool passed = false; + bool verified = false; + float runtime_ms = 0; + double tflops_tc_s = 0; + double tbytes_s = 0; + size_t smem_size = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IsPersistent { + static const bool value = v; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class TileShape, + class PersistenceOption = IsPersistent +> +struct Runner { + +#ifdef FP8 + using Element = cutlass::float_e4m3_t; +#elif FP16 + using Element = cutlass::half_t; +#else + #error "Must either define FP8 or FP16" +#endif + + using ElementAcc = float; + using ElementOut = cutlass::half_t; + + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple; + + using StrideQ = cute::tuple; // H D B + using StrideK = cute::tuple; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = std::conditional_t< + PersistenceOption::value, + 
Sm100MlaPersistentTileScheduler, + Sm100MlaIndividualTileScheduler + >; + + using Kernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler + >; + using Operation = cutlass::fmha::device::MLA; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q_latent; + StrideK stride_C_latent; + StrideQ stride_Q_rope; + StrideK stride_K_rope; + StrideO stride_O; + StrideLSE stride_LSE; + StrideLSE stride_PT; + + uint64_t seed = 0; + + int page_size = -1; + int page_count = -1; + + // We allocate Q and C as first latent, then rope + // This means that we offset the pointer by HeadDim_latent to get the rope + // portion + DeviceAllocation block_Q; + DeviceAllocation block_C; + DeviceAllocation block_O; + DeviceAllocation block_seq; + DeviceAllocation block_PT; + DeviceAllocation block_split_kv; + DeviceAllocation block_accum_split_len; + DeviceAllocation block_LSE; + DeviceAllocation block_ref_O; + DeviceAllocation block_ref_LSE; + + ElementAcc scale; + + // + // Methods + // + + bool verify(const ProblemShape& problem_shape) { + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int page_K = K; + int page_B = B; + if (block_PT.get() != nullptr) { + page_K = page_size; + page_B = page_count; + } + + Tensor mQ_latent = make_tensor(make_gmem_ptr(block_Q.get()), + cute::make_tuple(H, D_latent, B), + stride_Q_latent); + + Tensor mQ_rope = make_tensor(make_gmem_ptr(block_Q.get() + D_latent), + cute::make_tuple(H, D_rope, B), + stride_Q_rope); + + Tensor mC_latent = make_tensor(make_gmem_ptr(block_C.get()), + cute::make_tuple(page_K, D_latent, page_B), + stride_C_latent); + + Tensor mK_rope = make_tensor(make_gmem_ptr(block_C.get() + D_latent), + cute::make_tuple(page_K, D_rope, page_B), + stride_K_rope); + + Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()), + cute::make_tuple(H, D_latent, B), + stride_O); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()), + cute::make_tuple(H, B), + stride_LSE); + + Tensor mSeq = make_tensor(make_gmem_ptr(static_cast(block_seq.get())), make_shape(B)); + Tensor mPT = make_tensor(make_gmem_ptr(static_cast(block_PT.get())), make_shape(ceil_div(K, page_size), B), stride_PT); + + fmha_mla_reference(problem_shape, mSeq, mPT, mQ_latent, mQ_rope, mC_latent, mK_rope, mO, mLSE, scale); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2; + const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; +#ifdef B2B + reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff); +#else + reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff); +#endif + + bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_O) { + std::cerr << "failed O: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + bool passed_LSE = true; +#ifndef B2B + reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff); + + passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if ( ! 
passed_LSE) { + std::cerr << "failed LSE: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } +#endif + + return passed_O && passed_LSE; + } + + ProblemShape initialize(const Options& options) { + auto problem_shape = cute::make_tuple(TileShapeH{}, options.k, TileShapeD{}, options.b); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + // the scale is based on the non-absorbed sizes, change as appropriate + // we can't determine this parameter from the info we have, it's an input + int D_non_latent = 128; + scale = static_cast(1.0 / sqrt(1.0 * (D_non_latent + D_rope))); + // Shape (H, D, B) + stride_Q_latent = cute::make_tuple(static_cast(0 + D_latent + D_rope), _1{}, static_cast(H * (0 + D_latent + D_rope))); + stride_Q_rope = stride_Q_latent; + stride_O = cute::make_tuple(static_cast(0 + D_latent), _1{}, static_cast(0 + H * D_latent)); + stride_LSE = cute::make_tuple(_1{}, 0 + H); + + block_Q.reset(static_cast(options.b) * H * (D_latent + D_rope)); + block_O.reset(static_cast(options.b) * H * D_latent); + block_LSE.reset(static_cast(options.b) * H); + block_ref_O.reset(static_cast(options.b) * H * D_latent); + block_ref_LSE.reset(static_cast(options.b) * H); + + if (options.page == -1) { + + stride_C_latent = cute::make_tuple(static_cast(0 + D_latent + D_rope), _1{}, static_cast(options.k) * (D_latent + D_rope)); + stride_K_rope = stride_C_latent; + + block_C.reset(static_cast(options.b) * options.k * (D_latent + D_rope)); + + } + else { + + float spread = options.spread; + int max_K = static_cast((1 + spread) * K); + int min_K = static_cast((1 - spread) * K); + page_size = options.page; + page_count = B * ceil_div(max_K, page_size); + stride_PT = cute::make_stride(_1{}, page_count); + + std::vector host_seq(B); + std::vector host_PT(page_count * B); + + for (int i = 0; i < B; i++) { + int seq = min_K + rand() % (max_K - min_K + 1); + host_seq[i] = seq; + for (int j = 0; j < ceil_div(seq, page_size); j++) { + host_PT[page_count * i + j] = i + j * B; + } + } + + block_seq.reset(host_seq.size()); + block_seq.copy_from_host(host_seq.data(), host_seq.size()); + block_PT.reset(host_PT.size()); + block_PT.copy_from_host(host_PT.data(), host_PT.size()); + + get<1>(problem_shape) = max_K; + + stride_C_latent = cute::make_tuple(static_cast(0 + D_latent + D_rope), _1{}, page_size * static_cast((D_latent + D_rope))); + stride_K_rope = stride_C_latent; + + block_C.reset(page_count * page_size * static_cast((D_latent + D_rope))); + + if (options.is_var_split_kv == true) { + std::vector host_split_kv(B); + for(int i = 0; i < B; ++i) { + auto len = host_seq[i]; + int split = ceil_div(options.max_split_kv, ceil_div(max_K, len)); + host_split_kv[i] = split; + } + block_split_kv.reset(B); + block_split_kv.copy_from_host(host_split_kv.data(), host_split_kv.size()); + } + } + + initialize_block(block_Q, seed + 2023, options.init_style_q); + initialize_block(block_C, seed + 2022, options.init_style_c); + + return problem_shape; + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + + ProblemShape problem_shape = initialize(options); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + typename Operation::Arguments arguments{ + problem_shape, + { scale, + block_Q.get(), stride_Q_latent, + block_Q.get() + D_latent, stride_Q_rope, + block_C.get(), stride_C_latent, + block_C.get() + D_latent, stride_K_rope, + block_seq.get(), + block_PT.get(), stride_PT, + page_count, page_size}, + { block_O.get(), + stride_O, + 
block_LSE.get(), + stride_LSE}, + hw_info, + options.split_kv, + options.is_var_split_kv ? block_split_kv.get() : nullptr + }; + if (options.split_kv < 0 && !options.is_var_split_kv) { + Operation::set_split_kv(arguments); + } + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Kernel::SharedStorageSize; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + + cutlass::Status status = cutlass::Status::kSuccess; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + // Run + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + for (int i = 0; i < options.iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Wait for work on the device to complete. 
+ result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + runtime_ms /= static_cast(options.iterations); + + double flops = 1.0; + flops *= B; + flops *= K; + flops *= H; + flops *= 2.0; + flops *= (2.0 * D_latent + D_rope); + + double bytes_q = sizeof(Element); + bytes_q *= B; + bytes_q *= H; + bytes_q *= (D_latent + D_rope); + double bytes_c = sizeof(Element); + bytes_c *= B; + bytes_c *= options.k; // K may be max_K here + bytes_c *= (D_latent + D_rope); + double bytes_o = sizeof(ElementOut); + bytes_o *= B; + bytes_o *= H; + bytes_o *= D_latent; + double bytes = bytes_q + bytes_c + bytes_o; + + double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tflops_tc_s = tflops_s; + example_result.tbytes_s = tbytes_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_shape); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, ExampleResult result, bool verbose) { + std::ios fmt(nullptr); + fmt.copyfmt(std::cout); + std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] "); + std::cout << std::setw(32) << std::left << description; + std::cout.copyfmt(fmt); + std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl; + if (verbose) { + std::cout << " t=" << result.runtime_ms * 1e3 << " us, " + "smem=" << result.smem_size << "b" << std::endl; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, const char* name, auto... kernel_options) { + if ((! options.kernel_filter.empty()) && (! 
std::regex_search(name, std::basic_regex(options.kernel_filter)))) { + return; + } + Runner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using NumHeads = _128; + using HeadDimLatent = _512; + using HeadDim = Shape; + + std::cout << "###### B " << options.b << " MLA H " << 0 + NumHeads{} << " "; + std::cout << "D_rope " << 0 + get<1>(HeadDim{}) << " D_latent " << 0 + get<0>(HeadDim{}) << " "; + std::cout << "Q 1 K " << options.k << " Gen None "; + std::cout << "Split " << options.split_kv << " Gen None "; + std::cout << "#SM " << hw_info.sm_count << std::endl; + + using Blocking = _128; + std::string name = std::to_string((int) NumHeads{}) + "x" + std::to_string((int) Blocking{}); + std::string individual = " individual"; + std::string persistent = " persistent"; +#if FP8 + name += " fp8"; + // Persistent Tile Scheduler + run(Shape{}, (name + persistent).c_str(), IsPersistent{}); + // Individual Tile Scheduler + run(Shape{}, (name + individual).c_str(), IsPersistent{}); +#elif FP16 + name += " fp16"; + // Persistent Tile Scheduler + run(Shape{}, (name + persistent).c_str(), IsPersistent{}); + // Individual Tile Scheduler + run(Shape{}, (name + individual).c_str(), IsPersistent{}); +#endif +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +int main_single(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) { + std::cout + << "This example requires a GPU of NVIDIA's Blackwell Architecture " + << "(compute capability major 10) and CUDA 12.8 or greater.\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + if (options.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + else { + hw_info.sm_count = options.sm_count; + } + + run_mla(options, hw_info); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + std::vector full_arguments(args, args + argc); + + int result = 0; + + bool recursed = false; + for (size_t i = 1; i < full_arguments.size(); i++) { + if (full_arguments[i].find(',') != std::string::npos) { + auto arg = full_arguments[i]; + size_t eq_pos = arg.find('='); + std::string prefix = eq_pos == std::string::npos ? 
"" : arg.substr(0, eq_pos+1); + std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1); + for (;;) { + size_t comma_pos = rest.find(','); + std::string current = rest.substr(0, comma_pos); + full_arguments[i] = prefix + current; + std::vector next_args; + for (auto& elem : full_arguments) { next_args.push_back(elem.data()); } + main(argc, next_args.data()); + if (comma_pos == std::string::npos) break; + rest = rest.substr(comma_pos+1); + } + recursed = true; + break; + } + } + + if (! recursed) { + main_single(argc, args); + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 90b47387..bff609fa 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -35,6 +35,10 @@ set_property( SOURCE 77_blackwell_fmha_gen.cu PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") +set_property( + SOURCE 77_blackwell_mla.cu + PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") + set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no) set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen) @@ -48,58 +52,69 @@ set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify) set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap) set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only) -if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))) - if (CUTLASS_NVCC_ARCHS MATCHES 100a) - cutlass_example_add_executable( - 77_blackwell_fmha_fp8 - 77_blackwell_fmha.cu - TEST_COMMAND_OPTIONS - TEST_BASIC - # TEST_CAUSAL - # TEST_VARLEN - # TEST_HDIM64 - # TEST_GQA) - ) - target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8) +set(TEST_MLA_BASIC --b=1 --k=512 --verify) - cutlass_example_add_executable( - 77_blackwell_fmha_gen_fp8 - 77_blackwell_fmha_gen.cu - TEST_COMMAND_OPTIONS - TEST_GEN_BASIC - # TEST_GEN_VARLEN - # TEST_GEN_HDIM64 - # TEST_GEN_GQA - # TEST_GEN_REMAP - # TEST_GEN_CACHEONLY) - ) - target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8) +if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a)) - cutlass_example_add_executable( - 77_blackwell_fmha_fp16 - 77_blackwell_fmha.cu - TEST_COMMAND_OPTIONS - TEST_BASIC - # TEST_CAUSAL - # TEST_VARLEN - # TEST_HDIM64 - # TEST_GQA) - ) - target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + foreach(PREC fp8 fp16) + string(TOUPPER "${PREC}" PREC_MACRO) - cutlass_example_add_executable( - 77_blackwell_fmha_gen_fp16 - 77_blackwell_fmha_gen.cu - TEST_COMMAND_OPTIONS - TEST_GEN_BASIC - # TEST_GEN_VARLEN - # TEST_GEN_HDIM64 - # TEST_GEN_GQA - # TEST_GEN_REMAP - # TEST_GEN_CACHEONLY) - ) - target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - endif() + cutlass_example_add_executable( + 77_blackwell_fmha_${PREC} + 77_blackwell_fmha.cu + TEST_COMMAND_OPTIONS + TEST_BASIC + # TEST_CAUSAL + # TEST_VARLEN + # TEST_HDIM64 + # TEST_GQA) + ) + target_include_directories(77_blackwell_fmha_${PREC} PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO}) + + cutlass_example_add_executable( + 77_blackwell_fmha_gen_${PREC} + 77_blackwell_fmha_gen.cu + TEST_COMMAND_OPTIONS + TEST_GEN_BASIC + # TEST_GEN_VARLEN + # TEST_GEN_HDIM64 + # TEST_GEN_GQA + # TEST_GEN_REMAP + # TEST_GEN_CACHEONLY) + ) + target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO}) + + cutlass_example_add_executable( + 77_blackwell_mla_2sm_${PREC} + 77_blackwell_mla.cu + TEST_COMMAND_OPTIONS + TEST_MLA_BASIC + ) + target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO}) + target_compile_options(77_blackwell_mla_2sm_${PREC} PRIVATE -Xptxas -v) + + cutlass_example_add_executable( + 77_blackwell_mla_2sm_cpasync_${PREC} + 77_blackwell_mla.cu + TEST_COMMAND_OPTIONS + TEST_MLA_BASIC + ) + target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC) + target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v) + + cutlass_example_add_executable( + 77_blackwell_mla_b2b_2sm_${PREC} + 77_blackwell_mla.cu + TEST_COMMAND_OPTIONS + TEST_MLA_BASIC + ) + target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B) + target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v) + + endforeach() endif() diff --git a/examples/77_blackwell_fmha/README.md b/examples/77_blackwell_fmha/README.md index 2f4c9c76..c8250a7d 100644 --- a/examples/77_blackwell_fmha/README.md +++ b/examples/77_blackwell_fmha/README.md @@ -22,6 +22,24 @@ The `apply_mask` function is called with the accumulator of the first GEMM and t It is well-suited for applying masks or activations. More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA. +# MLA Inference for Blackwell + +This sample provides code for fused multi-head latent attention inference in +the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64. +It supports fp16, bf16, and fp8 input and output types. + +To accommodate the large output accumulator due to the large latent head dimension, +the sample demonstrates how to leverage 2Sm Blackwell tensor cores. + +Loading can be done via TMA (either without paging or with page size 128), or using `cp.async` +for support of any power-of-two page size less than or equal to 128. +With paging, the code also supports variable sequence length. + +The approach of this implementation is to reuse the selection logic of the collective GEMM builder and recombine the result into an MLA kernel. + +The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back GEMM (essentially turning the softmax into a no-op) for fp8 and fp16. +For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them. + # Copyright Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
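With paging enabled, `initialize()` in `77_blackwell_mla.cu` allocates `B * ceil_div(max_K, page_size)` physical pages and fills a page table that maps each (logical page, batch index) pair to one of them. The sketch below walks through that bookkeeping with illustrative values; `cutlass::ceil_div` from `cutlass/fast_math.h` stands in for the `ceil_div` used in the example, and none of this is part of the patch itself.

```cpp
// Sketch only: paged-KV page-table sizing in the spirit of 77_blackwell_mla.cu.
// Assumption: cutlass::ceil_div (round-up integer division) from "cutlass/fast_math.h".
#include "cutlass/fast_math.h"
#include <cstdio>

int main() {
  int const B = 2;            // batch size (illustrative)
  int const max_K = 1024;     // longest sequence length after applying the --spread slack
  int const page_size = 128;  // TMA path uses page size 128; cp.async allows any power of two <= 128
  int const pages_per_seq = cutlass::ceil_div(max_K, page_size);  // 8 logical pages per sequence
  int const page_count = B * pages_per_seq;                       // 16 physical pages backing the KV cache
  // The page table stores one physical page index per (logical page, batch) entry.
  std::printf("page table: %d x %d entries, %d physical pages\n", pages_per_seq, B, page_count);
  return 0;
}
```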
diff --git a/examples/77_blackwell_fmha/common/pow_2.hpp b/examples/77_blackwell_fmha/common/pow_2.hpp new file mode 100644 index 00000000..eca93250 --- /dev/null +++ b/examples/77_blackwell_fmha/common/pow_2.hpp @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator *(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator *(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } + +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { + printf("2^%d", a.log2_n); +} + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/examples/77_blackwell_fmha/device/sm100_mla.hpp b/examples/77_blackwell_fmha/device/sm100_mla.hpp new file mode 100644 index 00000000..4e098090 --- /dev/null +++ b/examples/77_blackwell_fmha/device/sm100_mla.hpp @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp" +#include "kernel/sm100_fmha_mla_reduction.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class Kernel_ +> +class MLA { +public: + + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, + typename Kernel::ElementAcc, + typename Kernel::ElementAcc, + Kernel::TileShapeH::value, + Kernel::TileShapeL::value, + 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments 
to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{ + nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, + args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, + args.ptr_split_kv, Kernel::TileShapeS::value + }; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + static void set_split_kv (KernelArguments& args) { + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + int sm_count = args.hw_info.sm_count; + int max_splits = ceil_div(K, 128); + int sms_per_batch = max(1, sm_count / B); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + args.split_kv = split_wave_aware; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (! Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (! ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = Kernel::get_grid_shape(params.fmha_params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms.fmha_params}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params.fmha_params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess != result or Status::kSuccess != launch_result) { + //return Status::kSuccess; + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + if (params.reduction_params.split_kv > 1) { + // launch reduction kernel + dim3 const block = ReductionKernel::get_block_shape(); + dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params); + device_kernel<<>>(params.reduction_params); + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + else { + return Status::kSuccess; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 00000000..c6a05750 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template< + class ElementOut, + class ElementAcc, + class ElementScale, + size_t kNumHeads, + size_t kHeadDimLatent, + int kMaxSplits +> +struct Sm100FmhaMlaReductionKernel { + + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, + args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator() (Params const& params, char* 
smem_raw) { + if (params.split_kv <= 1) return; + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), + Shape<_1>{}, Stride<_1>{}); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = max(lse_max, local_lse[i]); + } + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset)); + } + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? 
std::numeric_limits::infinity() : logf(sum_lse) + params.scale * lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 00000000..acb89a9d --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,2018 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "gather_tensor.hpp" // from examples/common +#include "common/pow_2.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class TileShape, + class Element_, + class ElementAcc_, + class ElementOut_, + class ElementLSE_, + class TileScheduler, +#ifdef CPASYNC + bool kIsCpAsync = true +#else + bool kIsCpAsync = false +#endif +> +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + }; + + static const long long unsigned int kWarpAssignment = kIsCpAsync ? 0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TensorStride, Alignment, + ElementAcc, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. 
check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, + ElementAcc, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent + // tile scheduler for FP8 MLA. 
+ static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + alignas(2048) cute::array> smem_p; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + 
EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + //workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + + auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, stride_c_latent_transpose, + }, nullptr); + + MainloopParams mainloop_params { + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b + }; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + + if (args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } + + return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B; + 
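+ // Workspace: one D_latent-wide ElementAcc partial-O accumulator plus one ElementLSE per (head, split, batch), consumed by the split-KV reduction kernel.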
} + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + return false; + } + } + else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + return false; + } + if (get<1>(args.problem_shape) <= 0) { + return false; + } + if (args.split_kv <= 0) { + return false; + } + return true; + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { + + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } + else { + pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s( + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename 
PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma( + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if (role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o( + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table( + shared_storage.pipelines.load_page_table, + pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? 
+ pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_page_table( + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, + local_split_kv + ); + } + } + else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_cpasync( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); 
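+ // Wait at the epilogue barrier so the next tile's Q load cannot overwrite shared memory the compute warps are still reading (see the kNamedBarrierEpilogue note above).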
+ cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + } + } + else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + mma(blk_coord, + problem_shape, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, + pipeline_p_mma, pipeline_p_mma_consumer_state, + pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv + ); + } + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + + //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + compute( + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, + pipeline_p_mma, pipeline_p_mma_producer_state, + pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv + ); + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + 
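+ // Only the MMA warp allocated TMEM, so it alone returns the columns after the final synchronization.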
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + +#if 1 + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard + ); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } +#endif + } + + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int * __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { + printf(""); + } + + }; + + + template + CUTLASS_DEVICE void load_cpasync( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, + int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + int 
batch_coord = get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout( + make_shape(threads, values), + make_stride(_1{}, _0{})), + rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor( + make_gmem_ptr(src_v_ptrs(i)), + make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{})) + ); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = 
make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + 
k_tile_count -= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + + } + + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + 
group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( + mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = tma_partition( + mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( + mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( + mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! + Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier 
= pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! 
+ // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + 
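+    // (Added explanatory note) Each K tile is processed as two back-to-back GEMMs:
+    // S = Q * K^T is accumulated into one of two TMEM buffers (kS0/kS1, chosen by the
+    // S-pipeline stage index), and O += P * V accumulates into the kO0.. TMEM buffers
+    // once the softmax warps have written the quantized P tile to shared memory.
+    // The QK accumulator is reset to ScaleOut::Zero for every tile, while the PV
+    // accumulator is zeroed only once, so O accumulates across all K tiles.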
pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * 
uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + + template + CUTLASS_DEVICE void softmax( + IsLastTile const& is_last_tile, + ElementAcc& row_max, + ElementAcc& row_sum, + ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + int k_index, + uint32_t tmem_s, + int smem_p_index) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + + Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + +#ifndef B2B + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } +#endif + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); 
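+    // (Added explanatory note) Rescaling identity used above: with row_max the previous
+    // running maximum, row_max_new the updated one, and s = softmax_scale,
+    //   exp(s * (x - row_max_new)) = exp(s * (x - row_max)) * exp(s * (row_max - row_max_new)),
+    // so already-accumulated row sums and the O accumulator are simply multiplied by
+    // correction_factor = exp2f(s * log2(e) * (row_max - row_max_new)) instead of being
+    // recomputed; exp2f with the folded log2(e) factor maps onto the fast hardware exponential.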
+ + // have a mapping for each thread to coord + // find identical mapping to coords for the MMA + auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2*i) { + cute::add(sums(j), sums(j), sums(j+i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + + CUTLASS_DEVICE void rescale( + ElementAcc correction_factor, + uint32_t tmem_o) { + + // for b2b gemm, do nothing +#ifndef B2B + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); +#endif + } + + + template + CUTLASS_DEVICE void epilogue( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + uint32_t tmem_o, + int const& split_kv) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + if (epilogue_args.ptr_o_acc != nullptr) { + using 
ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + #endif + } + else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (! 
kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } +#endif + } + } + + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue + ).arrive(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? 
TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + +#ifdef B2B + row_sum = 1; +#else + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } +#endif + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue( + row_max, row_sum, + replace<1>(cta_coord, j), problem_shape, + mainloop_args, epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv + ); + } + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp b/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 00000000..dbcc2ce8 --- /dev/null +++ b/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,160 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split KV*/ + + return Params { + num_blocks, + { num_m_blocks}, { get<3>(problem_shape) }, {split_kv}, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + 
using namespace cute; + int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel + diff --git a/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp new file mode 100644 index 00000000..29db9074 --- /dev/null +++ b/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp @@ -0,0 +1,206 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void __global__ fmha_mla_reference_kernel( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR mQR, + TensorCL mCL, TensorKR mKR, + TensorO mO, TensorLSE mLSE, + Scale softmax_scale) { + + using namespace cute; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ ElementAcc mS[]; + // ElementAcc* mS = reinterpret_cast(mS_mem); + + for (int idx_B = blockIdx.y; idx_B < B; idx_B += gridDim.y) { + if (mSeq.data() != nullptr) { + K = mSeq(idx_B); + } + + for (int idx_H = blockIdx.x; idx_H < H; idx_H += gridDim.x) { + + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + ElementAcc acc = 0; + + for (int idx_D = 0; idx_D < D_latent; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQL(idx_H, idx_D, idx_B); + ElementAcc eK = mCL(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + + for (int idx_D = 0; idx_D < D_rope; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQR(idx_H, idx_D, idx_B); + ElementAcc eK = mKR(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + mS[idx_K] = acc; + } + + __syncthreads(); + + ElementAcc maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < K; idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + +#ifndef B2B + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS)); + } +#endif + + __syncthreads(); + + ElementAcc sum = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + sum += mS[idx_K]; + } + + ElementAcc o_scale = 1.0f / sum; +#ifdef B2B + o_scale = 1.0; +#endif + + for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eV = mCL(page_idx_K, idx_D, page_idx_B); + ElementAcc eK = static_cast(mS[idx_K]); + acc += eK * eV; + } + mO(idx_H, idx_D, idx_B) = static_cast(acc * o_scale); + } + + if (threadIdx.x == 0) { + mLSE(idx_H, idx_B) = log(sum) + softmax_scale * maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void fmha_mla_reference( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR 
mQR,
+    TensorCL mCL, TensorKR mKR,
+    TensorO mO, TensorLSE mLSE,
+    Scale scale) {
+
+  using namespace cute;
+
+  auto [H, K, D, B] = problem_shape;
+  auto [D_latent, D_rope] = D;
+
+  dim3 grid(H, B, 1);
+  dim3 block(256);
+  int shared_mem = K * int(sizeof(typename TensorLSE::value_type)) + 16;
+  cudaError_t result;
+  if (shared_mem >= (48 << 10)) {
+    result = cudaFuncSetAttribute(
+        &fmha_mla_reference_kernel<ProblemShape, TensorSeq, TensorPageTable, TensorQL, TensorQR, TensorCL, TensorKR, TensorO, TensorLSE, Scale>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize,
+        shared_mem);
+    if (cudaSuccess != result) {
+      result = cudaGetLastError();  // to clear the error bit
+      throw std::runtime_error("couldn't perform smem optin");
+    }
+  }
+  fmha_mla_reference_kernel<<<grid, block, shared_mem>>>(
+      problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale);
+  cudaDeviceSynchronize();
+  result = cudaGetLastError();
+  if (cudaSuccess != result) {
+    throw std::runtime_error("couldn't execute reference");
+  }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp
index e4a01c82..6d833ad1 100644
--- a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp
+++ b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp
@@ -178,3 +178,96 @@ void reference_abs_diff(
   max_diff = result_host[0];
   mean_diff = result_host[1] / static_cast<double>(data.size());
 }
+
+template <class Element>
+__global__ void reference_rel_diff_kernel(
+    Element* data, Element* data_ref, size_t count,
+    double* max_diff, double* sum_diff,
+    bool print_diff ) {
+
+  double thread_max_diff = 0;
+  double thread_sum_diff = 0;
+
+  __shared__ double block_max_diff;
+  __shared__ double block_sum_diff;
+
+  for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
+    double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]);
+    if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long>(i), diff, (double)data[i], (double)data_ref[i]);
+    thread_max_diff = fmax(diff, thread_max_diff);
+    thread_sum_diff += diff;
+  }
+
+  for (int i = 0; i < blockDim.x; i++) {
+    if (i == threadIdx.x) {
+      if (i == 0) {
+        block_max_diff = thread_max_diff;
+        block_sum_diff = thread_sum_diff;
+      }
+      else {
+        block_max_diff = fmax(block_max_diff, thread_max_diff);
+        block_sum_diff += thread_sum_diff;
+      }
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x == 0) {
+    atomicAdd(sum_diff, block_sum_diff);
+
+    for (;;) {
+      unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
+      double prev_diff = reinterpret_cast<double&>(prev);
+      double new_max_diff = fmax(block_max_diff, prev_diff);
+      unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long&>(new_max_diff));
+      if (found == prev) break;
+    }
+  }
+}
+
+template <class Element>
+void reference_rel_diff(
+    DeviceAllocation<Element> const& data,
+    DeviceAllocation<Element> const& data_ref,
+    double& max_diff, double& mean_diff) {
+
+  static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
+
+  DeviceAllocation<double> result;
+  result.reset(2);
+  assert(data.size() == data_ref.size());
+
+  cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
+  if (err != cudaSuccess) {
+    std::cerr << "Memset failed. 
Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_rel_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu new file mode 100644 index 00000000..32df1146 --- /dev/null +++ b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu @@ -0,0 +1,554 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture. + This kernel is optimized for the GeForce RTX 50 series GPUs. 
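+
+    Note (added for clarity): "sparse" here refers to structured sparsity of the A operand.
+    A is stored in compressed form alongside a separate metadata tensor E that records which
+    elements along the K dimension were kept; the example also keeps an uncompressed copy of A
+    so the result can be validated against a dense reference GEMM.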
+ + The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions: + * mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale. + Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution. + + The kernel leverages: + 1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. Block Scaled Sparse Tensor Core MMA Instructions + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. + + Usage: + $ ./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" +using namespace cute; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::mx_float8_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 16; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / 
cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// E matrix configuration. Note, E is used to represent metadata tensor. +using ElementE = uint8_t; // Element type for E matrix operand +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag +using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120; // Kernel schedule policy +// Kernel Perf config +using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelScheduleType // Mainloop schedule policy + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +// +// Data members +// +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +LayoutE layout_E; +uint64_t seed; +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_A_Decompressed; +cutlass::HostTensor block_E; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) 
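Editorial aside, not part of the example source: given the Gemm type defined above, the host-side flow of a CUTLASS 3.x example generally follows the pattern sketched below. The argument object is a placeholder here; the real argument construction (compressed A, metadata E, scale factors SFA/SFB, alpha/beta) happens further down in this file.

  // Minimal sketch of the canonical CUTLASS 3.x run flow (names are placeholders).
  Gemm gemm;
  typename Gemm::Arguments args{};   // would be filled with problem size, pointers, strides, scale factors
  size_t workspace_size = Gemm::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  CUTLASS_CHECK(gemm.can_implement(args));               // check the configuration is supported
  CUTLASS_CHECK(gemm.initialize(args, workspace.get())); // set up internal state and workspace
  CUTLASS_CHECK(gemm.run());                             // launch the kernel on the default stream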
+template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// +// Command line options parsing +struct Options { + bool help; + float alpha, beta; + int iterations; + int m, n, k; + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + out << "80a_blackwell_geforce_mxfp8_bf16_sparse_gemm\n\n" + << " Blackwell MXFP8 Sparse GEMM is a warp specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + out << "\n\nExamples:\n\n" + << "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + return out; + } + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} +}; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} +/// Initialize blocks that released to sparse Matrix A and its metadata E 
+bool initialize_sparse_blocks(const Options &options) { + auto workload = make_shape(options.m, + options.n, + options.k, + 1); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + /// Alias SparseConfig and Compressor + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig, + cutlass::arch::Sm120>; + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + /// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs + CompressorUtility compressor_utility(workload, stride_A); + // Aligned M K dimension size for A and E + int aligned_m_e = compressor_utility.get_metadata_m_physical(); + int aligned_k_e = compressor_utility.get_metadata_k_physical(); + int aligned_m_a = compressor_utility.get_tensorA_m_physical(); + int aligned_k_a = compressor_utility.get_tensorA_k_physical(); + /// Layout A and E + layout_A = SparseConfig::fill_layoutA(workload); + layout_E = SparseConfig::fill_layoutE(workload); + + block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a)); + block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e)); + block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k)); + initialize_block(block_A_Decompressed.host_view(), seed + 2020); + compressor_utility.structure_sparse_zero_mask_fill( + block_A_Decompressed.host_data(), static_cast(seed + 2021)); + block_A_Decompressed.sync_device(); + + /// Use compressor kernel to generate compressed Matrix A and E + cutlass::Status status { cutlass::Status::kSuccess }; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {options.m, options.n, options.k, 1}, + {block_A_Decompressed.device_data(), + stride_A, + block_A.device_data(), + block_E.device_data()}, + {hw_info} + }; + + // Compress A and E + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + block_A.sync_host(); + block_E.sync_host(); + return true; +} +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + using namespace cute; + + // Initial A, E(metadata) and A_compressed blocks + if(!initialize_sparse_blocks(options)) return false; + + // Define B, C and D blocks + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_B = make_layout(make_shape(options.n, 
options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + // Define SFA and SFB tensors layouts + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + return true; +} +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), layout_A, + block_B.device_data(), stride_B, + block_E.device_data(), layout_E, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + return arguments; +} +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D}; + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + // Comparison: compare the device result against the host reference, and guard against all-zero outputs + block_D.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + return passed; +} +/// Execute a given
example GEMM computation +template +int run(Options &options) +{ + // Initialization + if(!initialize(options)) + { + std::cerr << " Initialization failed! " << std::endl; + exit(-1); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + cudaDeviceSynchronize(); + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + return 0; +} +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +/////////////////////////////////////////////////////////////////////////////////////////////////// +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 120. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are a no-op. + return 0; + } + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(&current_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)."
<< std::endl; + return 0; + } + // + // Parse options + // + Options options; + options.parse(argc, args); + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + return 0; +} +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu new file mode 100644 index 00000000..f3441b56 --- /dev/null +++ b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu @@ -0,0 +1,578 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM120 architecture. + + This example demonstrates a simple way to instantiate and run a narrow precision blockscaled sparse GEMM on the NVIDIA Blackwell SM120 architecture. + This kernel is optimized for the GeForce RTX 50 series GPUs. + + The Blackwell SM120 CUTLASS kernel uses the new Block Scaled Sparse Tensor Core MMA Instructions: + * mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale. + Please see more detail in https://docs.nvidia.com/cuda/parallel-thread-execution. + + The kernel leverages: + 1. Warp-Specialized persistent kernel design that supports cooperative scheduler introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. 
Block Scaled Sparse Tensor Core MMA Instructions + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. + + Usage: + $ ./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" +using namespace cute; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// C/D matrix configuration +using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int outputVectorSize = 32; // Vector size for D matrix +using outputScaleFactor = cutlass::float_ue4m3_t; // Scale factor type for D matrix +// E matrix configuration. Note, E is used to represent metadata tensor. 
+using ElementE = uint8_t; // Element type for E matrix operand +// Kernel functional config +using ElementCompute = float; // Element type for computation inside mainloop and epilogue +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag +using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120; // Kernel schedule policy +// Kernel Perf config +using ThreadBlockShape = Shape<_128,_128,_256>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::SparseTmaWarpSpecializedCooperativeSm120, // Epilogue schedule policy + cutlass::epilogue::fusion::LinCombBlockScaleFactor< // Epilogue fusion to generate nvfp4 output + outputVectorSize, ElementD, ElementAccumulator, outputScaleFactor, LayoutDTag, ElementC> + >::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelScheduleType // Mainloop schedule policy + >::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig; +using LayoutSFD = typename SfdOutputCfg::LayoutSF; +// +// Data members +// +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +LayoutSFD layout_SFD; +LayoutE layout_E; +uint64_t seed; +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_A_Decompressed; +cutlass::HostTensor block_E; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; 
+cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +cutlass::HostTensor block_SFD; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +cutlass::HostTensor block_reference_SFD; +cutlass::HostTensor block_Normconst; +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// +// Command line options parsing +struct Options { + bool help; + float alpha, beta; + int iterations; + int m, n, k; + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + out << "80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm\n\n" + << " Blackwell NVFP4 Sparse GEMM is a warp specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + out << "\n\nExamples:\n\n" + << "$ " << "./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + return out; + } + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} +}; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr
(cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} +/// Initialize blocks that released to sparse Matrix A and its metadata E +bool initialize_sparse_blocks(const Options &options) { + auto workload = make_shape(options.m, + options.n, + options.k, + 1); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + /// Alias SparseConfig and Compressor + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA::DataType, + LayoutATag, + SparseConfig, + cutlass::arch::Sm120>; + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + /// Declare compressor_utility to randomly fill zero in Matrix A to match sparsity needs + CompressorUtility compressor_utility(workload, stride_A); + // Aligned M K dimension size for A and E + int aligned_m_e = compressor_utility.get_metadata_m_physical(); + int aligned_k_e = compressor_utility.get_metadata_k_physical(); + int aligned_m_a = compressor_utility.get_tensorA_m_physical(); + int aligned_k_a = compressor_utility.get_tensorA_k_physical(); + /// Layout A and E + layout_A = SparseConfig::fill_layoutA(workload); + layout_E = SparseConfig::fill_layoutE(workload); + + block_A.reset(cutlass::make_Coord(aligned_m_a * aligned_k_a)); + block_E.reset(cutlass::make_Coord(aligned_m_e * aligned_k_e)); + block_A_Decompressed.reset(cutlass::make_Coord(options.m * options.k)); + initialize_block(block_A_Decompressed.host_view(), seed + 2020); + compressor_utility.structure_sparse_zero_mask_fill( + block_A_Decompressed.host_data(), static_cast(seed + 2021)); + block_A_Decompressed.sync_device(); + + /// Use compressor kernel to generate compressed Matrix A and E + cutlass::Status status { cutlass::Status::kSuccess }; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {options.m, options.n, options.k, 1}, + {block_A_Decompressed.device_data(), + stride_A, + block_A.device_data(), + block_E.device_data()}, + {hw_info} + }; + + // Compress A and E + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + block_A.sync_host(); + block_E.sync_host(); + return true; +} +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + using namespace cute; + + // Initial A, E(metadata) and A_compressed blocks + if(!initialize_sparse_blocks(options)) return false; + + // Define B, C and D blocks + using Sm1xxBlkScaledConfig = typename 
Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1)); + // Define SFA and SFB tensors layouts + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD)))); + block_Normconst.reset(cutlass::make_Coord(1)); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + block_Normconst.at(cutlass::make_Coord(0)) = 2; + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); + block_SFD.sync_device(); + block_Normconst.sync_device(); + return true; +} +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), layout_A, + block_B.device_data(), stride_B, + block_E.device_data(), layout_E, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data(); + arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data(); + return arguments; +} +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A_Decompressed.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + Tensor tensor_E = make_tensor(make_iterator(block_E.host_data()), layout_E); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, 
tensor_SFB}; + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + auto tensor_SFD = cute::make_tensor(block_reference_SFD.host_data(), layout_SFD); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D), // TensorD + decltype(tensor_SFD), // TensorSfD + cute::Int, + cutlass::reference::host::SfStrategy::SfDGen + > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))}; + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + // Comparison: compare the device result against the host reference, and guard against all-zero outputs + block_D.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + return passed; +} +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + // Initialization + if(!initialize(options)) + { + std::cerr << " Initialization failed! " << std::endl; + exit(-1); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + cudaDeviceSynchronize(); + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + return 0; +} +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +/////////////////////////////////////////////////////////////////////////////////////////////////// +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 120.
+ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are a no-op. + return 0; + } + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(&current_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + return 0; + } + // + // Parse options + // + Options options; + options.parse(argc, args); + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + return 0; +} +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt new file mode 100644 index 00000000..6a94fb0d --- /dev/null +++ b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ + +if (CUTLASS_NVCC_ARCHS MATCHES 120a) +cutlass_example_add_executable( + 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm + 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu +) + +cutlass_example_add_executable( + 80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm + 80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu +) + +endif() diff --git a/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu b/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu new file mode 100644 index 00000000..f955b8e9 --- /dev/null +++ b/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu @@ -0,0 +1,869 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Distributed GEMM (DistGEMM) for Blackwell. + + This example runs Tensor Parallel GEMMs using the (experimental) Distributed GEMM API in + CUTLASS. For more information, please refer to README.md. + + Note that Distributed GEMM assumes an any-to-any NVLink network topology. + To check whether your device is compatible, run: + + $ nvidia-smi topo -m + + and make sure there's an any-to-any NVLink topology. 
It would look like this: + + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 + GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 + GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 + GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 + GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 + GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 + GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 + GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 + GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X + + You should also additionally check if the driver enables peer to peer access: + + $ nvidia-smi topo -p2p r + + Output should be something like this: + + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 + GPU0 X OK OK OK OK OK OK OK + GPU1 OK X OK OK OK OK OK OK + GPU2 OK OK X OK OK OK OK OK + GPU3 OK OK OK X OK OK OK OK + GPU4 OK OK OK OK X OK OK OK + GPU5 OK OK OK OK OK X OK OK + GPU6 OK OK OK OK OK OK X OK + GPU7 OK OK OK OK OK OK OK X + + It is recommended to build this target with the following flag to enable + Grid Dependency Control instructions (GDC) in CUTLASS: + - CUTLASS_ENABLE_GDC_FOR_SM100 + + Example: + + $ mkdir build && cd build + + $ cmake .. -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1 + + $ cd examples/82_blackwell_distributed_gemm + + $ make + + $ ./82_blackwell_distributed_gemm +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +// Distributed GEMM headers +#include "cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp" +#include "cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp" +#include "cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp" + +#include "helper.h" + +// Distributed GEMM helpers +#include "dist_gemm_helpers.h" + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Distributed GEMM configuration +///////////////////////////////////////////////////////////////////////////////////////////////// + +// TP size (= number of processors/GPUs) +using TP = _8; +static constexpr int TP_ = TP{}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + +// Distributed GEMM tiling/sharding schedule +// Choices: +// +// * All Gather + GEMM: +// * AllGather1D_TilingCD_RotatingA +// * AllGather1D_TilingCD_RotatingB +// +// * GEMM + Reduce Scatter: +// * ReduceScatter1D_TilingA_RotatingC +// * ReduceScatter1D_TilingB_RotatingC + +using DistSchedule = cutlass::distributed::schedules::AllGather1D_TilingCD_RotatingA; + 
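All of the schedules above assume direct peer-to-peer access between every pair of the TP GPUs, which is exactly what the `nvidia-smi topo` checks described earlier verify. The sketch below is a standalone programmatic version of the same check; `check_all_to_all_p2p` is a hypothetical helper added here for illustration and is not part of this example, which performs an equivalent check during its own initialization.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Returns true only if every ordered pair among the first `tp` devices reports
// peer-access capability (the any-to-any requirement for Distributed GEMM).
bool check_all_to_all_p2p(int tp) {
  for (int i = 0; i < tp; ++i) {
    for (int j = 0; j < tp; ++j) {
      if (i == j) continue;
      int can_access = 0;
      if (cudaDeviceCanAccessPeer(&can_access, i, j) != cudaSuccess || !can_access) {
        std::printf("GPU %d cannot directly access GPU %d\n", i, j);
        return false;
      }
    }
  }
  return true;
}
```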
+///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +using ElementD = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_256,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; +// Shape of the tile computed by each SM +using PerSmTileShape_MNK = Shape<_128, _256, _128>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +// We're going to use the single-device GEMM as reference +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Instantiate Distributed GEMM kernel 
+using DistGemmKernel = cutlass::distributed::kernel::DistributedGemmKernelWrapper< + GemmKernel, + DistSchedule +>; +using DistGemm = cutlass::distributed::device::DistributedGemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +using HostTensorA = typename cutlass::HostTensor; +using HostTensorB = typename cutlass::HostTensor; +using HostTensorC = typename cutlass::HostTensor; +using HostTensorD = typename cutlass::HostTensor; + +// Reference GEMM tensors +HostTensorA tensor_A; +HostTensorB tensor_B; +HostTensorC tensor_C; +HostTensorD tensor_D; +HostTensorD tensor_ref_D; + +// DistGEMM tensors (multi-device) +HostTensorA tensor_A_arr[TP_]; +HostTensorB tensor_B_arr[TP_]; +HostTensorD tensor_C_arr[TP_]; +HostTensorD tensor_D_arr[TP_]; + +#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 100; + int warmup_iterations = 10; + int m = 16384, n = 106496, k = 16384, l = 1; + float eps = 0.f; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup-iterations", warmup_iterations); + cmd.get_cmd_line_argument("eps", eps); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "82_blackwell_distributed_gemm\n\n" + << " Blackwell Distributed GEMM (DistGEMM). 
\n" + << " For more details please refer to the source file.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch) of the GEMM (default: 1)\n" + << " --alpha= Epilogue scalar alpha (default: 1.0)\n" + << " --beta= Epilogue scalar beta (default: 0.0)\n" + << " --iterations= Number of profiling iterations to perform (default: 100)\n" + << " --warmup-iterations= Number of warmup iterations prior to profiling (default: 10)\n" + << " --eps= Threshold for error compared to reference " + << "GEMM (default: 0.0)\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "82_blackwell_distributed_gemm" << " --m=16384 --n=106496 --k=16384 \n\n"; + + return out; + } + + /// Compute performance in TFLOP/s + double tflops(double runtime_s) const { + + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l / TP_; + double tflop = double(flop) / double(1.0e12); + return tflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double tflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double tflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), tflops(tflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed, + bool is_device_tensor = false) { + + double scope_max, scope_min; + int bits = cutlass::sizeof_bits::value; + + if (bits == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits <= 16) { + scope_max = 2; + scope_min = -2; + } + else { + scope_max = 8; + scope_min = -8; + } + + if (is_device_tensor) { + using Real = typename cutlass::RealType::Type; + cutlass::reference::device::TensorFillRandomUniform( + view, seed, static_cast(scope_max), static_cast(scope_min), 0); + cudaDeviceSynchronize(); + } else { + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l); + + // Setup (reference) GEMM tensors + auto shape_A = cute::select<0,2,3>(problem_shape); + auto shape_B = cute::select<1,2,3>(problem_shape); + auto shape_C = cute::select<0,1,3>(problem_shape); + auto shape_D = cute::select<0,1,3>(problem_shape); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, shape_A); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_C); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_D); + + auto a_coord = cutlass::make_Coord(size(shape_A), 1); + auto b_coord = cutlass::make_Coord(size(shape_B), 1); + auto c_coord = cutlass::make_Coord(size(shape_C), 1); + + 
tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.device_view(), seed + 2022, /* is_device_tensor = */ true); + initialize_tensor(tensor_B.device_view(), seed + 2023, /* is_device_tensor = */ true); + initialize_tensor(tensor_C.device_view(), seed + 2024, /* is_device_tensor = */ true); + + tensor_A.sync_host(); + tensor_B.sync_host(); + tensor_C.sync_host(); + tensor_D.sync_host(); + tensor_ref_D.sync_host(); + + // Set up DistGEMM tensors + auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape); + auto local_shape_B = DistSchedule::get_local_b_shape(problem_shape); + auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape); + auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape); + + auto a_coord_device = cutlass::make_Coord(size(local_shape_A), 1); + auto b_coord_device = cutlass::make_Coord(size(local_shape_B), 1); + auto c_coord_device = cutlass::make_Coord(size(local_shape_C), 1); + + int primary_device_idx; + CUDA_CHECK(cudaGetDevice(&primary_device_idx)); + + // Enable any-to-any access + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + int can_access; + CUDA_CHECK(cudaSetDevice(device_idx)); + for (int peer_idx = 0; peer_idx < TP_; ++peer_idx) { + if (peer_idx != device_idx) { + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, device_idx, peer_idx)); + if (not can_access) { + std::cerr << "FAILURE: Device " << device_idx << " can't access device " << peer_idx << "." << + std::endl; + exit(EXIT_FAILURE); + } + CUDA_CHECK(cudaDeviceEnablePeerAccess(peer_idx, 0)); + } + } + + tensor_A_arr[device_idx].resize(a_coord_device); + tensor_B_arr[device_idx].resize(b_coord_device); + tensor_C_arr[device_idx].resize(c_coord_device); + tensor_D_arr[device_idx].resize(c_coord_device); + } + CUDA_CHECK(cudaSetDevice(primary_device_idx)); +} + +/// Commandline options -> Gemm/DistGemm Arguments +using GemmArguments = typename Gemm::Arguments; +GemmArguments gemm_args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {static_cast(options.alpha), static_cast(options.beta)}, + tensor_C.device_data(), stride_C, + tensor_ref_D.device_data(), stride_D + } + }; + + return arguments; +} + +using DistGemmArguments = typename DistGemm::Arguments; +DistGemmArguments dist_gemm_args_from_options( + const Options &options, + int device_idx, + cudaStream_t stream) { + + auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l); + + auto global_A = cute::make_tensor(tensor_A.device_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto global_B = cute::make_tensor(tensor_B.device_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto global_C = cute::make_tensor(tensor_C.device_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + + auto global_A_device_slice = DistSchedule::get_device_slice_A(global_A, device_idx); + auto global_B_device_slice = DistSchedule::get_device_slice_B(global_B, device_idx); + auto global_C_device_slice = DistSchedule::get_device_slice_C(global_C, device_idx); + + auto local_shape_A = DistSchedule::get_local_a_shape(problem_shape); + auto local_shape_B = 
DistSchedule::get_local_b_shape(problem_shape); + auto local_shape_C = DistSchedule::get_local_c_shape(problem_shape); + auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape); + + auto local_stride_A = cutlass::make_cute_packed_stride(StrideA{}, local_shape_A); + auto local_stride_B = cutlass::make_cute_packed_stride(StrideB{}, local_shape_B); + auto local_stride_C = cutlass::make_cute_packed_stride(StrideC{}, local_shape_C); + auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D); + + auto local_A = cute::make_tensor( + tensor_A_arr[device_idx].device_data(), + make_layout(local_shape_A, local_stride_A)); + auto local_B = cute::make_tensor( + tensor_B_arr[device_idx].device_data(), + make_layout(local_shape_B, local_stride_B)); + auto local_C = cute::make_tensor( + tensor_C_arr[device_idx].device_data(), + make_layout(local_shape_C, local_stride_C)); + auto local_D = cute::make_tensor( + tensor_D_arr[device_idx].device_data(), + make_layout(local_shape_D, local_stride_D)); + + // Copy over tensor tiles for the first iteration + cutlass::device_copy(global_A_device_slice, local_A, stream); + cutlass::device_copy(global_B_device_slice, local_B, stream); + cutlass::device_copy(global_C_device_slice, local_C, stream); + + DistGemmArguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, // mode + problem_shape, // problem shape + { + reinterpret_cast(local_A.data()), + local_A.stride(), + reinterpret_cast(local_B.data()), + local_B.stride() + }, // mainloop + { + { // epilogue.thread + static_cast(options.alpha), + static_cast(options.beta) + }, + reinterpret_cast(local_C.data()), + local_C.stride(), + reinterpret_cast(local_D.data()), + local_D.stride(), + }, // epilogue + {}, // hw_info + {} // scheduler + }; + + return arguments; +} + +// Gathers results, moves back to the original full-sized D tensor on the primary device. 
+void gather_results(const Options &options, int device_idx, cudaStream_t stream = nullptr) { + + auto problem_shape = cute::make_tuple(options.m, options.n, options.k, options.l); + + // Global dest + auto global_D = cute::make_tensor(tensor_D.device_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto global_D_device_slice = DistSchedule::get_device_slice_D(global_D, device_idx); + + // Device_idx local dest + auto local_shape_D = DistSchedule::get_local_d_shape(problem_shape); + auto local_stride_D = cutlass::make_cute_packed_stride(StrideD{}, local_shape_D); + auto local_D = cute::make_tensor( + tensor_D_arr[device_idx].device_data(), + make_layout(local_shape_D, local_stride_D) + ); + + // Copy to global dest + cutlass::device_copy(local_D, global_D_device_slice, stream); +} + +bool verify(const Options &options) { + tensor_D.sync_host(); + tensor_ref_D.sync_host(); + + bool passed = false; + if (options.eps == 0.f) { + passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + } else { + double err = cutlass::reference::host::TensorRelativeErrorMetric( + tensor_D.host_view(), + tensor_ref_D.host_view()); + passed = err < 1e-5; + } + + if (options.m <= 64 && options.n <= 64) { + std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n"; + std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n"; + } + + return passed; +} + +/// Execute a given example GEMM computation +int run(Options &options) { + + int primary_device_idx; + cudaError_t device_get_result = cudaGetDevice(&primary_device_idx); + if (device_get_result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + initialize(options); + + // Reference single-GPU GEMM + Gemm reference_gemm; + cutlass::device_memory::allocation reference_workspace; + + auto reference_arguments = gemm_args_from_options(options); + size_t reference_workspace_size = Gemm::get_workspace_size(reference_arguments); + reference_workspace = cutlass::device_memory::allocation(reference_workspace_size); + + CUTLASS_CHECK(reference_gemm.can_implement(reference_arguments)); + CUTLASS_CHECK(reference_gemm.initialize(reference_arguments, reference_workspace.get())); + CUTLASS_CHECK(reference_gemm.run()); + + using ElementBarrier = typename DistGemm::ElementBarrier; + using ElementFlag = typename DistGemmKernel::ElementFlag; + + // Set up per-device streams + cudaStream_t stream_arr[TP_]; + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + + // Create stream + CUDA_CHECK(cudaStreamCreate(&stream_arr[device_idx])); + } + + // Instantiate DistGEMM + DistGemm dist_gemm_arr[TP_]; // Distributed GEMM array for multiple devices + + // Allocate workspace memory + cutlass::device_memory::allocation workspace_arr[TP_]; + cutlass::device_memory::allocation exclusive_workspace_arr[TP_]; + + // Cross-device workspace pointer array for gemm.initialize() + void * workspace_ptr_arr[TP_]; + void * exclusive_workspace_ptr_arr[TP_]; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + DistGemmArguments arguments_[TP_]; + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + + arguments_[device_idx] = dist_gemm_args_from_options(options, device_idx, stream_arr[device_idx]); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = 
DistGemm::get_workspace_size(arguments_[device_idx]); + size_t exclusive_workspace_size = DistGemm::get_exclusive_workspace_size(); + + workspace_arr[device_idx] = cutlass::device_memory::allocation(workspace_size); + exclusive_workspace_arr[device_idx] = cutlass::device_memory::allocation(exclusive_workspace_size); + + // Throw workspace pointers into arrays for gemm.initialize() + workspace_ptr_arr[device_idx] = workspace_arr[device_idx].get(); + exclusive_workspace_ptr_arr[device_idx] = exclusive_workspace_arr[device_idx].get(); + + // Zero out exclusive workspace + cudaMemsetAsync(exclusive_workspace_ptr_arr[device_idx], 0, exclusive_workspace_size, stream_arr[device_idx]); + + cudaDeviceSynchronize(); + } + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + + // Check if the problem size is supported or not + CUTLASS_CHECK(dist_gemm_arr[device_idx].can_implement(arguments_[device_idx])); + +#if defined(CUTLASS_ENABLE_GDC_FOR_SM100) + bool launch_with_pdl = true; +#else + bool launch_with_pdl = false; +#endif + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(dist_gemm_arr[device_idx].initialize( + arguments_, + workspace_ptr_arr, + exclusive_workspace_ptr_arr, + device_idx, + stream_arr[device_idx], + launch_with_pdl + )); + + cudaDeviceSynchronize(); + } + + // Correctness / Warmup iteration + std::cout << std::endl << " running DistGEMM..." << std::endl; + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx])); + } + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx])); + CUDA_CHECK(cudaGetLastError()); + gather_results(options, device_idx); + } + + std::cout << " running DistGEMM finished without runtime errors" << std::endl; + + //// Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + + result.passed = verify(options); + + std::cout << std::endl << " Disposition (eps: " << options.eps << "): " << + (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) { + float elapsed_ms = 0.f; + + // Warmup + std::cout << " Warming up for " << options.warmup_iterations << " iterations." << std::endl; + for (int warmup_iter = 0; warmup_iter < options.warmup_iterations; ++warmup_iter) { + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx])); + } + } + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + CUDA_CHECK(cudaStreamSynchronize(stream_arr[device_idx])); + } + + CUDA_CHECK(cudaSetDevice(primary_device_idx)); + + // Benchmark + std::cout << " Profiling for " << options.iterations << " iterations." 
<< std::endl; + using AtomicBoolean = cuda::atomic; + AtomicBoolean* atomic_flag_ptr; + CUDA_CHECK(cudaHostAlloc(&atomic_flag_ptr, sizeof(AtomicBoolean), cudaHostAllocPortable)); + atomic_flag_ptr->store(false); + + cutlass::DistGpuTimer timer; + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + cutlass::delay_kernel<<<1, 1, 0, stream_arr[device_idx]>>>(atomic_flag_ptr); + CUDA_CHECK(cudaGetLastError()); + } + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + timer.start(device_idx, stream_arr[device_idx]); + } + + atomic_flag_ptr->store(true); + + for (int profile_iter = 0; profile_iter < options.iterations; ++profile_iter) { + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + CUTLASS_CHECK(dist_gemm_arr[device_idx].run(stream_arr[device_idx])); + } + } + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + CUDA_CHECK(cudaSetDevice(device_idx)); + timer.stop(device_idx, stream_arr[device_idx]); + } + + CUDA_CHECK(cudaSetDevice(primary_device_idx)); + + for (int device_idx = 0; device_idx < TP_; ++device_idx) { + elapsed_ms = max(elapsed_ms, timer.elapsed_millis(device_idx)); + } + + // Compute average runtime and TFLOPs. + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0); + result.tflops = options.tflops(avg_runtime_s); + + auto [local_M, local_N, local_K, local_L] = DistSchedule::get_local_gemm_shape( + cute::make_tuple(options.m, options.n, options.k, options.l)); + + std::cout << std::endl; + std::cout << " TP: " << TP::value << std::endl; + std::cout << " Problem Size: " << + options.m << " x " << + options.n << " x " << + options.k << " x " << + options.l << std::endl; + std::cout << " Local GEMM Problem Size: " << + local_M << " x " << + local_N << " x " << + local_K << " x " << + local_L<< std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " TFLOPS: " << result.tflops << std::endl; + } + + return 0; +} + +#endif // (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example + // and must have compute capability at least 90. + // Some necessary cuda graph APIs were only introduced in CUDA 12.4. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) { + std::cerr << "This example requires CUDA 12.4 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + int num_devices; + CUDA_CHECK(cudaGetDeviceCount(&num_devices)); + if (num_devices < TP_) { + std::cerr << "Distributed GEMM is compiled with TP = " << TP::value << ", but " << + "found only " << num_devices << " devices." << + std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 10 || props.minor != 0) {
+    std::cerr
+      << "This example requires a GPU of NVIDIA's Blackwell Architecture "
+      << "(compute capability 100), "
+      << "got compute capability " << props.major * 10 + props.minor << "."
+      << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+
+#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+  run(options);
+#endif
+
+  return 0;
+}
diff --git a/examples/82_blackwell_distributed_gemm/CMakeLists.txt b/examples/82_blackwell_distributed_gemm/CMakeLists.txt
new file mode 100644
index 00000000..fa8fe9ad
--- /dev/null
+++ b/examples/82_blackwell_distributed_gemm/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+cutlass_example_add_executable(
+  82_blackwell_distributed_gemm
+  82_blackwell_distributed_gemm.cu
+  )
diff --git a/examples/82_blackwell_distributed_gemm/README.md b/examples/82_blackwell_distributed_gemm/README.md
new file mode 100644
index 00000000..6f6c19b8
--- /dev/null
+++ b/examples/82_blackwell_distributed_gemm/README.md
@@ -0,0 +1,37 @@
+# Blackwell Distributed GEMM
+
+This example implements Tensor Parallel GEMMs for the Blackwell architecture with the experimental
+[Distributed GEMM](../../include/cutlass/experimental/distributed) API in CUTLASS.
+
+This example requires Blackwell GPUs with an any-to-any NVLink network.
+Please refer to [REQUIREMENTS.md](REQUIREMENTS.md) for more information.
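+
+Once CUTLASS has been configured as described in [REQUIREMENTS.md](REQUIREMENTS.md), this example
+builds like any other CUTLASS example target. As a rough sketch (the `build` directory name and the
+default Makefile generator are assumptions, not part of this example):
+
+```bash
+cd build
+cmake .. -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
+make 82_blackwell_distributed_gemm -j
+```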
+ +By default, the example assumes 8 GPUs (TP=8) and runs an All Gather + GEMM operation, which rotates +operand A. To run with a different number of GPUs or schedule, please refer to +[82_blackwell_distributed_gemm.cu](82_blackwell_distributed_gemm.cu). + + +## Getting started + +Command line arguments are mostly similar to other examples: + +``` +--m= Sets the M extent of the GEMM +--n= Sets the N extent of the GEMM +--k= Sets the K extent of the GEMM +--l= Sets the L extent (batch) of the GEMM (default: 1) +--alpha= Epilogue scalar alpha (default: 1.0) +--beta= Epilogue scalar beta (default: 0.0) +--iterations= Number of profiling iterations to perform (default: 100) +--warmup-iterations= Number of warmup iterations prior to profiling (default: 10) +--eps= Threshold for error compared to reference GEMM (default: 0.0) +``` + +Sample run command: + +```bash +./82_blackwell_distributed_gemm --m=16384 --n=106496 --k=16384 --warmup-iterations=10 --iterations=100 +``` + +This example follows the [Hopper example](../65_distributed_gemm/) very closely, and only differs in the base GEMM kernel. For +more information you can refer to [that example](../65_distributed_gemm/README.md). diff --git a/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md b/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md new file mode 100644 index 00000000..3943716b --- /dev/null +++ b/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md @@ -0,0 +1,86 @@ +# Blackwell Distributed GEMM + +## Requirements + +### Build +Make sure to set up CUTLASS with +support for [Programmatic Dependent Launch (PDL)](../../media/docs/dependent_kernel_launch.md), +that is with the `CUTLASS_ENABLE_GDC_FOR_SM100` flag. + +```bash +cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1 +``` + +### Minimum software + +Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required. +This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary +CUDA graph APIs. + +### Hardware / driver settings + +This example requires Blackwell GPUs with NVLink network. 
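+
+In addition to the checks below, every pair of GPUs must be able to access each other's memory
+directly (peer-to-peer); the example verifies and enables this itself during initialization with
+`cudaDeviceCanAccessPeer` and `cudaDeviceEnablePeerAccess`. A minimal standalone sketch of that
+check (not part of the example) is:
+
+```cpp
+#include <cstdio>
+#include <cuda_runtime.h>
+
+int main() {
+  int num_devices = 0;
+  cudaGetDeviceCount(&num_devices);
+  // Require P2P access between every ordered pair of devices.
+  for (int i = 0; i < num_devices; ++i) {
+    for (int j = 0; j < num_devices; ++j) {
+      if (i == j) continue;
+      int can_access = 0;
+      cudaDeviceCanAccessPeer(&can_access, i, j);
+      if (!can_access) {
+        std::printf("Device %d cannot access device %d over P2P.\n", i, j);
+        return 1;
+      }
+    }
+  }
+  std::printf("All %d devices can access each other over P2P.\n", num_devices);
+  return 0;
+}
+```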
+ +If you're not sure, first run the following command and make sure your GPU +compute capability is 10.0: + +```bash +nvidia-smi --query-gpu=name,compute_cap --format=csv +``` + +Sample output: + +``` +name, compute_cap +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +NVIDIA B200, 10.0 +``` + + +Then you should make sure there is an NVLink network by checking the GPU network topology, +and making sure there's `NV*` links between every pair of GPUs: + +```bash +nvidia-smi topo -m +``` + +Sample output: + +``` + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 +GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 +GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 +GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 +GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 +GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 +GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 +GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 +GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X +``` + +Finally, check if the driver enables peer to peer access, which should usually be the case, +but it's good to check anyway: + +```bash +nvidia-smi topo -p2p r +``` + +Sample output: + +``` + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 +GPU0 X OK OK OK OK OK OK OK +GPU1 OK X OK OK OK OK OK OK +GPU2 OK OK X OK OK OK OK OK +GPU3 OK OK OK X OK OK OK OK +GPU4 OK OK OK OK X OK OK OK +GPU5 OK OK OK OK OK X OK OK +GPU6 OK OK OK OK OK OK X OK +GPU7 OK OK OK OK OK OK OK X +``` diff --git a/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu b/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu new file mode 100644 index 00000000..d4280472 --- /dev/null +++ b/examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu @@ -0,0 +1,607 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief A FP16 sparse GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. + + The Blackwell SM100 CUTLASS kernel uses of the following Blackwell SM100 features: + + 1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a) + which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA). + + Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + 2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a). + Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the + Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Usage: + $ ./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm --m=8192 --n=8192 --k=8192 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = half_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 2 * 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k + +// E matrix config +using ElementE = cute::uint8_t; + +// B matrix configuration +using ElementB = half_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = float; // Element type for D matrix operand +using ElementC 
= float; // Element type for C matrix operand +using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +using LayoutTagD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassSparseTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_64>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementD, LayoutTagD, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized2Sm + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveoutEpi, + cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100 + >::CollectiveOp; + +using ProblemShape = Shape; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutTagA, + ElementB, + LayoutTagB, + ElementC, + LayoutTagC, + ElementAccumulator, + ElementAccumulator>; + +// Layouts +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Compressor +// +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +/// Initialization +LayoutA layout_A; +LayoutE layout_E; +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +uint64_t seed; + +ProblemShape 
problem_shape; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_A_compressed; +cutlass::DeviceAllocation block_E; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(8192), n(8192), k(8192), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "83_blackwell_sparse_gemm\n\n" + << " Blackwell FP16 Sparse GEMM example.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "83_blackwell_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } + else if constexpr (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } + else { + scope_max = Element(8); + scope_min = Element(-8); + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), 
seed, scope_max, scope_min, 0); + return true; +} + +/// Make A structured sparse by replacing elements with 0 and compress it +bool sparsify_and_compress() +{ + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + block_A_compressed.reset(M * KAlignedAC * L); + block_E.reset(MAlignedE * KAlignedE * L); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + // Random 50% fill zero is performed on host + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast(seed + 2024)); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments { + problem_shape, + { block_A.get(), + stride_A, + block_A_compressed.get(), + block_E.get() }, + {hw_info} }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + if (status != cutlass::Status::kSuccess) { + return false; + } + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + + // Compress row A and get A_compress and E + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + if (not sparsify_and_compress()) { + return false; + }; + + // Build the compressed/metadata layouts + layout_A = 
SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + + return true; +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E }, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + auto init_pass = initialize(options); + if (not init_pass) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (not result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+
+    std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
+    std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << " GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with CUDA Toolkit 12.8 or higher to run this example
+  // and must have compute capability at least 100.
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
+    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (not (props.major == 10 && props.minor == 0)) {
+    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+  run(options);
+#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/83_blackwell_sparse_gemm/CMakeLists.txt b/examples/83_blackwell_sparse_gemm/CMakeLists.txt
new file mode 100644
index 00000000..765ef4c4
--- /dev/null
+++ b/examples/83_blackwell_sparse_gemm/CMakeLists.txt
@@ -0,0 +1,38 @@
+
+# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if (CUTLASS_NVCC_ARCHS MATCHES 100a) + +cutlass_example_add_executable( + 83_blackwell_sparse_gemm + 83_blackwell_sparse_gemm.cu +) + +endif() diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu b/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu new file mode 100644 index 00000000..d2d87c46 --- /dev/null +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu @@ -0,0 +1,693 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 Sparse GEMM on the NVIDIA Blackwell SM100 architecture. 
+ + The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced + on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma) + and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Similar to 83_blackwell_sparse_gemm, this kernel leverages: + 1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Usage: + $ ./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e2m1_t; +using ElementAPair = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k + +// E matrix config +using ElementE = cute::uint8_t; +using LayoutTagE = LayoutTagA; + +// B matrix configuration +using ElementB = cutlass::float_e2m1_t; +using ElementBPair = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// SF +using ElementSF = typename ElementAPair::ScaleFactorType; + +// C/D matrix configuration 
+using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape = Shape<_256,_128,_256>; +// Shape of the threadblocks in a cluster +using ClusterShape = Shape<_2,_1,_1>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementD, LayoutTagD, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized2SmNvf4 + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementAPair, LayoutTagA, AlignmentA, + ElementBPair, LayoutTagB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveoutEpi, + cutlass::gemm::KernelSparseTmaWarpSpecialized2SmNvf4Sm100 + >::CollectiveOp; + +using ProblemShape = Shape; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// +// Blockscale +// +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; +using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; +using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
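+// Note: unlike A/B/C/D, the scale-factor operands are described by full CuTe layouts rather than
+// plain strides because Sm1xxBlkScaledConfig interleaves them: each ElementSF scales a small block
+// of K elements of the corresponding operand (16 elements per scale factor for NVFP4), tiled over
+// Blk_MN rows.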
+ +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Compressor +// +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; + +typename LayoutTagA::Stride stride_factor_A; +typename LayoutTagB::Stride stride_factor_B; +typename LayoutTagE::Stride stride_factor_E; +typename LayoutTagC::Stride stride_factor_C; +typename LayoutTagD::Stride stride_factor_D; + +uint64_t seed; + +ProblemShape problem_shape; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_A_compressed; +cutlass::HostTensor tensor_E; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_D; +cutlass::HostTensor reference_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(1024), n(1024), k(1024), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "84a_blackwell_nvfp4_bf16_sparse_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input <= 8) { + if constexpr (cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + + // * Get A B C D size + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape); + + // * Get ACompress & E size + CompressorUtility compressor_utility(problem_shape, stride_A); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = 
compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l)); + + // * Get SFA & SFB size + auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(options.m, Blk_MN{}); + auto n_blks = cutlass::ceil_div(options.n, Blk_MN{}); + + // * Allocate Tensor + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto d_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D), false); + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + // * Random init + initialize_tensor(tensor_A.host_view(), seed + 2021); + initialize_tensor(tensor_B.host_view(), seed + 2022); + initialize_tensor(tensor_C.host_view(), seed + 2023); + initialize_tensor(tensor_SFA.host_view(), seed + 2024); + initialize_tensor(tensor_SFB.host_view(), seed + 2025); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + // * Random fill 50% A with zero + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + // * Compress + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + problem_shape, + {tensor_A.device_data(), + stride_A, + tensor_A_compressed.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = 
Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + if (status != cutlass::Status::kSuccess) { + return false; + } + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + tensor_E.sync_host(); + tensor_A_compressed.sync_host(); + + return true; +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { + reinterpret_cast(tensor_A_compressed.device_data()), layout_A, + reinterpret_cast(tensor_B.device_data()), stride_B, + tensor_E.device_data(), layout_E, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB + }, + { + {options.alpha, options.beta}, + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + + // Create the arguments for host reference implementation + auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A); + auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(options.n, options.k, options.l), stride_B)); + auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettMainloopParams< + ElementAccumulator, + decltype(A), + decltype(B), + decltype(SFA), + decltype(SFB)> mainloop_params{A, SFA, B, SFB}; + + auto C = make_tensor(make_iterator(tensor_C.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_C)); + auto D = make_tensor(make_iterator(reference_D.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_D)); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(C), // TensorC + decltype(D) // TensorD + > epilogue_params{ + options.alpha, + options.beta, + C, + D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + auto init_pass = initialize(options); + if (not init_pass) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // 
Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (not result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (not (props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." 
<< std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu b/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu new file mode 100644 index 00000000..a23af158 --- /dev/null +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu @@ -0,0 +1,695 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A Narrow Precision Sparse GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled MXFP8 Sparse GEMM on the NVIDIA Blackwell SM100 architecture. + + The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced + on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma) + and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Similar to 83_blackwell_sparse_gemm, this kernel leverages: + 1. 
Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/). + + 2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM + which allows us to decouple the execution of MMA and epilogue into separate warps. + + 3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + + Usage: + $ ./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; +using ElementAPair = cutlass::mx_float8_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 64; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes), 2x for compress along k + +// E matrix config +using ElementE = cute::uint8_t; +using LayoutTagE = LayoutTagA; + +// B matrix configuration +using ElementB = cutlass::float_e2m1_t; +using ElementBPair = cutlass::mx_float4_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// SF +using ElementSF = typename ElementAPair::ScaleFactorType; + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutTagC = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutTagD = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 
16 bytes) +constexpr int AlignmentC = (16 * 8) / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledSparseTensorOp; // Operator class tag + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_256,_128,_256>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_2,_1,_1>; + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementD, LayoutTagD, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4 + >::CollectiveOp; + +// Build the mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementAPair, LayoutTagA, AlignmentA, + ElementBPair, LayoutTagB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveoutEpi, + cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100 + >::CollectiveOp; + +using ProblemShape = Shape; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// +// Blockscale +// +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; +using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; +using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
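+// Editorial sketch (not part of the original example): the *Pair types consumed by
+// the mainloop builder above bundle a narrow-precision data type with its scale
+// factor type and scale-factor vector size. Assuming the current CUTLASS
+// definitions, they expand roughly as:
+//
+//   using ElementAPair = cutlass::mx_float8_t<cutlass::float_e4m3_t>;  // e4m3 data + ue8m0 scale factors
+//   using ElementBPair = cutlass::mx_float4_t<cutlass::float_e2m1_t>;  // e2m1 data + ue8m0 scale factors
+//
+// which is why the scale-factor element can be recovered above as
+// ElementAPair::ScaleFactorType.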
+ +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Compressor +// +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; + +typename LayoutTagA::Stride stride_factor_A; +typename LayoutTagB::Stride stride_factor_B; +typename LayoutTagE::Stride stride_factor_E; +typename LayoutTagC::Stride stride_factor_C; +typename LayoutTagD::Stride stride_factor_D; + +uint64_t seed; + +ProblemShape problem_shape; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_A_compressed; +cutlass::HostTensor tensor_E; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_D; +cutlass::HostTensor reference_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(1024), n(1024), k(1024), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "84b_blackwell_mixed_mxfp8_bf16_sparse_gemm\n\n" + << " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +void initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input <= 8) { + if constexpr (cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(const Options &options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + + // * Get A B C D size + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape); + + // * Get ACompress & E size + CompressorUtility compressor_utility(problem_shape, stride_A); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = 
compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, KAlignedAC, options.l)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, options.l)); + + // * Get SFA & SFB size + auto k_blks = cutlass::ceil_div(options.k, cute::size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(options.m, Blk_MN{}); + auto n_blks = cutlass::ceil_div(options.n, Blk_MN{}); + + // * Allocate Tensor + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto e_coord = cutlass::make_Coord(MAlignedE * options.l, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * options.l, KAlignedAC); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto d_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * options.l, k_blks * Blk_SF{}); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_compressed.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(d_coord, stride_factor_D), false); + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + // * Random init + initialize_tensor(tensor_A.host_view(), seed + 2021); + initialize_tensor(tensor_B.host_view(), seed + 2022); + initialize_tensor(tensor_C.host_view(), seed + 2023); + initialize_tensor(tensor_SFA.host_view(), seed + 2024); + initialize_tensor(tensor_SFB.host_view(), seed + 2025); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + // * Random fill 50% A with zero + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + // * Compress + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + problem_shape, + {tensor_A.device_data(), + stride_A, + tensor_A_compressed.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = 
Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + if (status != cutlass::Status::kSuccess) { + return false; + } + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + return false; + } + + tensor_E.sync_host(); + tensor_A_compressed.sync_host(); + + return true; +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { + reinterpret_cast(tensor_A_compressed.device_data()), layout_A, + reinterpret_cast(tensor_B.device_data()), stride_B, + tensor_E.device_data(), layout_E, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB + }, + { + {options.alpha, options.beta}, + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + + // Create the arguments for host reference implementation + auto A = make_tensor(make_iterator(tensor_A.host_data()), layout_A); + auto SFA = make_tensor(tensor_SFA.host_data(), layout_SFA); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(options.n, options.k, options.l), stride_B)); + auto SFB = make_tensor(tensor_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettMainloopParams< + ElementAccumulator, + decltype(A), + decltype(B), + decltype(SFA), + decltype(SFB)> mainloop_params{A, SFA, B, SFB}; + + auto C = make_tensor(make_iterator(tensor_C.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_C)); + auto D = make_tensor(make_iterator(reference_D.host_data()), + make_layout(make_shape(options.m, options.n, options.l), stride_D)); + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementScalingFactor + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(C), // TensorC + decltype(D) // TensorD + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(tensor_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + auto init_pass = initialize(options); + if (not init_pass) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a 
structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (not result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (not (props.major == 10 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/84_blackwell_narrow_precision_sparse_gemm/CMakeLists.txt b/examples/84_blackwell_narrow_precision_sparse_gemm/CMakeLists.txt new file mode 100644 index 00000000..751590b7 --- /dev/null +++ b/examples/84_blackwell_narrow_precision_sparse_gemm/CMakeLists.txt @@ -0,0 +1,41 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if (CUTLASS_NVCC_ARCHS MATCHES 100a) +cutlass_example_add_executable( + 84a_blackwell_nvfp4_bf16_sparse_gemm + 84a_blackwell_nvfp4_bf16_sparse_gemm.cu + ) + +cutlass_example_add_executable( + 84b_blackwell_mixed_mxfp8_bf16_sparse_gemm + 84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu + ) +endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 0f03cd9b..bfee2c3c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -158,7 +158,11 @@ foreach(EXAMPLE 77_blackwell_fmha 78_blackwell_emulated_bf16x9_gemm 79_blackwell_geforce_gemm + 80_blackwell_geforce_sparse_gemm 81_blackwell_gemm_blockwise + 82_blackwell_distributed_gemm + 83_blackwell_sparse_gemm + 84_blackwell_narrow_precision_sparse_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/examples/README.md b/examples/README.md index 92779c07..3f79df9a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -286,6 +286,18 @@ Blackwell SM120 MMA kernel targeting GeForce RTX 50 series CUDA Cores +* [80_blackwell_geforce_sparse_gemm](80_blackwell_geforce_sparse_gemm/) + + Blackwell SM120 sparse MMA kernel targeting GeForce RTX 50 series CUDA Cores + +* [83_blackwell_sparse_gemm](83_blackwell_sparse_gemm) + + Blackwell SM100 Sparse Gemm kernel + +* [84_blackwell_narrow_precision_sparse_gemm](84_blackwell_narrow_precision_sparse_gemm) + + Blackwell Block Scaled SM100 Sparse Gemm kernel + # CuTe - Programming Examples Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/). 
diff --git a/examples/65_distributed_gemm/util/benchmark.h b/examples/common/dist_gemm_helpers.h similarity index 69% rename from examples/65_distributed_gemm/util/benchmark.h rename to examples/common/dist_gemm_helpers.h index 66a0dbb5..ef258e69 100644 --- a/examples/65_distributed_gemm/util/benchmark.h +++ b/examples/common/dist_gemm_helpers.h @@ -44,6 +44,11 @@ #include #include +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/cuda_host_adapter.hpp" + namespace cutlass { @@ -115,4 +120,46 @@ struct DistGpuTimer { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Generic device-to-device data movement kernel based for CuTe tensors. +/// +/// NOTE: this kernel assigns one element copy to every thread, and is by no means +/// an efficient way of copying tensors. It should only be used for convenience in +/// reference checks. +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream); + + +template +__global__ void device_copy_kernel(TensorSource const tensor_source, + TensorDestination tensor_destination) { + auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + using ElementSrc = typename TensorSource::value_type; + using ElementDst = typename TensorDestination::value_type; + NumericConverter converter; + if (linear_idx < size(tensor_source)) { + tensor_destination(linear_idx) = converter(tensor_source(linear_idx)); + } +} + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream) { + + assert(tensor_source.size() == tensor_destination.size()); + + auto numel = tensor_source.size(); + static constexpr int NumThreads = 128; + auto grid_size = cute::ceil_div(numel, NumThreads); + + dim3 grid(grid_size); + dim3 block(NumThreads); + device_copy_kernel<<>>(tensor_source, tensor_destination); +} + } //namespace cutlass diff --git a/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp b/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp index ad9abfb3..7ac336a3 100644 --- a/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp +++ b/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp @@ -340,7 +340,7 @@ public: base_args.epilogue.thread, reinterpret_cast(tensor_c_iter.data()), tensor_c_iter.stride(), - reinterpret_cast(tensor_d_iter.data()), + reinterpret_cast(tensor_d_iter.data()), tensor_d_iter.stride() }; diff --git a/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp b/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp index a9a40cfe..b2900310 100644 --- a/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp +++ b/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp @@ -82,7 +82,7 @@ struct DistributedGemmKernelWrapper< using BaseArguments = typename BaseKernel::Arguments; using BaseParams = typename BaseKernel::Params; - static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now."); + //static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now."); static_assert(not cute::is_same_v, "DistributedGEMM epilogues must have a source."); using ElementFlag 
= uint32_t; diff --git a/include/cutlass/gemm/collective/builders/sm1xx_common.inl b/include/cutlass/gemm/collective/builders/sm1xx_common.inl index cb9e74c1..6cdd76d2 100644 --- a/include/cutlass/gemm/collective/builders/sm1xx_common.inl +++ b/include/cutlass/gemm/collective/builders/sm1xx_common.inl @@ -189,6 +189,100 @@ template< bool Is2sm = false > constexpr bool sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(){ + // * 1SM Dense + // * A_K(t) : TileShape_K % 128 == 0 + // * A_M(n) : TileShape_M % 128 == 0 + // * B_N(t) : TileSize_N % 128 == 0 + // * B_K(n) : TileSize_K % 128 == 0 + // + // * 2SM Dense + // * A_K(t) : TileShape_K % 128 == 0 + // * A_M(n) : TileShape_M % 128 == 0 + // * B_N(t) : TileSize_N % 256 == 0 + // each sm load half the data along tile_n (split vertically), each sm needs to be 128 elts aligned. + // full tile_n needs to be 256 elts aligned + // * B_K(n) : TileShape_K % 128 == 0 + // + // * 1SM Sparse + // * A_K(t) : TileShape_K % 256 == 0 + // num of physical elems needs to be 128 elts aligned + // num of logical elems needs to be 256 elts aligned + // * A_M(n) : TileShape_M % 128 == 0 + // * B_N(t) : TileSize_N % 128 == 0 + // * B_K(n) : TileSize_K % 128 == 0 + // + // * 2SM Sparse + // * A_K(t) : TileShape_K % 256 == 0 + // num of physical elems needs to be 128 elts aligned + // num of logical elems needs to be 256 elts aligned + // * A_M(n) : TileShape_M % 128 == 0 + // * B_N(t) : TileSize_N % 256 == 0 + // each sm load half the data along tile_n (split vertically), each sm needs to be 128 elts aligned. + // full tile_n needs to be 256 elts aligned + // * B_K(n) : TileShape_K % 128 == 0 + // + // * Valid TileShape_MNK Dense + // * Notation: + // mma_instruction_tile_shape-cta_tile_shape + // * s128x128x64 + // s128x128x32_128x128x128_nn YES + // s128x128x32_128x128x128_nt YES + // s128x128x32_128x128x128_tn YES + // s128x128x32_128x128x128_tt YES + // * s128x256x64 + // s128x256x32_128x256x128_nn YES + // s128x256x32_128x256x128_nt YES + // s128x256x32_128x256x128_tn YES + // s128x256x32_128x256x128_tt YES + // * s256x128x64 + // s256x128x32_256x128x128_nn YES + // s256x128x32_256x128x128_nt NO (2SM B_N TileSize_N % 256 != 0) + // s256x128x32_256x128x128_tn YES + // s256x128x32_256x128x128_tt NO (2SM B_N TileSize_N % 256 != 0) + // * s256x256x64 + // s256x256x32_256x256x128_nn YES + // s256x256x32_256x256x128_nt YES + // s256x256x32_256x256x128_tn YES + // s256x256x32_256x256x128_tt YES + // + // * Valid TileShape_MNK Sparse + // * s128x128x64 + // s128x128x64_128x128x128_nn YES + // s128x128x64_128x128x128_nt YES + // s128x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0) + // s128x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0) + // s128x128x64_128x128x256_nn YES + // s128x128x64_128x128x256_nt YES + // s128x128x64_128x128x256_tn YES + // s128x128x64_128x128x256_tt YES + // * s128x256x64 + // s128x256x64_128x256x128_nn YES + // s128x256x64_128x256x128_nt YES + // s128x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0) + // s128x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0) + // s128x256x64_128x256x256_nn YES + // s128x256x64_128x256x256_nt YES + // s128x256x64_128x256x256_tn YES + // s128x256x64_128x256x256_tt YES + // * s256x128x64 + // s256x128x64_128x128x128_nn YES + // s256x128x64_128x128x128_nt NO (2SM B_N TileSize_N % 256 != 0) + // s256x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0) + // s256x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0) + // s256x128x64_128x128x256_nn YES + // s256x128x64_128x128x256_nt NO (2SM 
B_N TileSize_N % 256 != 0) + // s256x128x64_128x128x256_tn YES + // s256x128x64_128x128x256_tt NO (2SM B_N TileSize_N % 256 != 0) + // * s256x256x64 + // s256x256x64_128x256x128_nn YES + // s256x256x64_128x256x128_nt YES + // s256x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0) + // s256x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0) + // s256x256x64_128x256x256_nn YES + // s256x256x64_128x256x256_nt YES + // s256x256x64_128x256x256_tn YES + // s256x256x64_128x256x256_tt YES + [[maybe_unused]] constexpr int TileShape_M = Is2sm ? size<0>(TileShape_MNK{}) / 2 : size<0>(TileShape_MNK{}); [[maybe_unused]] constexpr int TileShape_N = size<1>(TileShape_MNK{}); [[maybe_unused]] constexpr int TileShape_K = size<2>(TileShape_MNK{}); diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp index 5714151f..25a68671 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -432,6 +432,10 @@ public: init_M = get<0>(problem_shape_MNK); init_N = get<1>(problem_shape_MNK); init_K = get<2>(problem_shape_MNK); + if constexpr (SwapAB) { + init_M = get<1>(problem_shape_MNK); + init_N = get<0>(problem_shape_MNK); + } if constexpr (not SwapAB) { dA = args.dA; @@ -491,7 +495,7 @@ public: : args_setup(args.ptr_A, args.ptr_B); } else if constexpr (ModeHasScales) { - auto scale_k = 1; + auto scale_k = ceil_div(init_K, args.chunk_size); ElementScale const* ptr_S = reinterpret_cast(args.ptr_S); StrideScale dS{}; Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS)); @@ -595,7 +599,7 @@ public: } else if constexpr (ModeHasScales) { const int scale_mn = SwapAB ? N : M; - const int scale_k = (K + args.chunk_size - 1) / args.chunk_size; + const int scale_k = ceil_div(K, args.chunk_size); constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0)); @@ -659,14 +663,15 @@ public: return cute::make_tuple(gA_mkl, gB_nkl); } else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? 
N : M; auto scale_k = mainloop_params.scale_k; - Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(scale_mn,scale_k,L)); Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(scale_mn,scale_k,L)); Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); } @@ -1217,8 +1222,8 @@ public: Params const& mainloop_params, int32_t next_group, ProblemShape_MNKL problem_shape_mnkl) { - const uint32_t M = get<0>(problem_shape_mnkl); - const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t M = (SwapAB? get<1>(problem_shape_mnkl) : get<0>(problem_shape_mnkl)); + const uint32_t N = (SwapAB? get<0>(problem_shape_mnkl) : get<1>(problem_shape_mnkl)); const uint32_t K = get<2>(problem_shape_mnkl); // Replace all dims for consistency @@ -1245,14 +1250,14 @@ public: if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { NonVoidElementScale const* ptr_S = nullptr; - auto scale_k = 1; + auto scale_k = ceil_div(K, mainloop_params.chunk_size); Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale, prob_shape_scale, prob_stride_scale); } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { ElementZero const* ptr_Z = nullptr; - auto scale_k = 1; + auto scale_k = ceil_div(K, mainloop_params.chunk_size); Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero, prob_shape_zero, prob_stride_zero); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index acc183d0..a3f35c5c 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -426,7 +426,7 @@ public: return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB }; } else if constexpr (ModeHasScales) { - auto scale_k = (K + args.group_size - 1) / args.group_size; + auto scale_k = ceil_div(K, args.group_size); ElementScale const* ptr_S = args.ptr_S; StrideScale dS = args.dS; Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS)); @@ -483,7 +483,7 @@ public: } else if constexpr (ModeHasScales) { const int scale_mn = SwapAB ? 
N : M; - const int scale_k = (K + args.group_size - 1) / args.group_size; + const int scale_k = ceil_div(K, args.group_size); constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; check_aligned_S = cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), args.dS); check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); diff --git a/include/cutlass/pipeline/sm100_pipeline.hpp b/include/cutlass/pipeline/sm100_pipeline.hpp index 4ebd8b5d..44b9d4d4 100644 --- a/include/cutlass/pipeline/sm100_pipeline.hpp +++ b/include/cutlass/pipeline/sm100_pipeline.hpp @@ -622,6 +622,11 @@ public: impl_.producer_acquire(state, barrier_token); } + CUTLASS_DEVICE + void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) { + impl_.producer_expect_transaction(state, transaction_bytes); + } + // NOP for TMA based mainloop CUTLASS_DEVICE void producer_commit(PipelineState state, uint32_t bytes) { diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 6b766fc2..c253f7ff 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -452,6 +452,11 @@ public: return producer_get_barrier(state.index()); } + CUTLASS_DEVICE + void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) { + producer_expect_transaction(state.index(), transaction_bytes); + } + //////////////////// // Consumer APIs //////////////////// @@ -519,6 +524,14 @@ private: #endif } + CUTLASS_DEVICE + void producer_expect_transaction(uint32_t stage, uint32_t transaction_bytes) { + detail::pipeline_check_is_producer(params_.role); + if (params_.is_leader) { + full_barrier_ptr_[stage].expect_transaction(transaction_bytes); + } + } + // NOP for TMA based mainloop CUTLASS_DEVICE void producer_commit(uint32_t stage, uint32_t bytes) { diff --git a/media/docs/cpp/blackwell_functionality.md b/media/docs/cpp/blackwell_functionality.md index e751a124..582899d3 100644 --- a/media/docs/cpp/blackwell_functionality.md +++ b/media/docs/cpp/blackwell_functionality.md @@ -9,15 +9,15 @@ efficient SM100 GEMM kernels targeting these new mma instructions. Blackwell SM100 has 7 new `tcgen05.mma` instructions. These instructions are 2x to 4x faster then Hopper Architecture's WGMMA instructions. 
-| Ptx Instruction | Throughput | Notes | -|----------------------------------------------------------------------------------|----------------------------|-------| -|tcgen05.mma.cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts | -|tcgen05.mma.cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts | -|tcgen05.mma.cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts | -|tcgen05.mma.cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts | -|tcgen05.mma.cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts | -|tcgen05.mma.cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts | -|tcgen05.mma.cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts | +| Ptx Instruction | Throughput | Notes | +|---------------------------------------------------------------------------------------|----------------------------|-------| +|tcgen05.mma(.sp).cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32} TN, NT, TT, NN layouts | +|tcgen05.mma(.sp).cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16} TN, NT, TT, NN layouts | +|tcgen05.mma(.sp).cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8} TN, NT, TT, NN layouts | +|tcgen05.mma(.sp).cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8} TN, NT, TT, NN layouts | +|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8} with TN, NT, TT, NN layouts | +|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} with TN layouts | +|tcgen05.mma(.sp).cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4} with TN layouts | For more detailed information see [`tcgen05.mma` PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensorcore-5th-generation-family-instructions). @@ -27,7 +27,7 @@ For more detailed information see [`tcgen05.mma` PTX documentation](https://docs Instructions with `kind` modifiers `mxf8f6f4`, `mxf4`, and `nvf4mxf4` perform matrix multiplication operations with scale factors of the form $D = C +( A \times SFA) * (B \times SFB)$. Scale factors are applied to GEMM-K dimension such that -every 16 or 32 elements of $A$ and $B$ matrices in K dimension have an associated scale factor. For example, an $M\times K$, +every 16 or 32 elements of $A$ and $B$ matrices in K dimension have an associated scale factor (32 or 64 elements for sparse as sparse gemm compress 2x along k-dim). 
For example, an $M\times K$ $A$ matrix has an associated $M \times \lceil K/32 \rceil$ SFA matrix; and an $N\times K$ $B$ matrix has an associated $N \times \lceil K/32 \rceil$ SFB matrix. For block scaled GEMMs, an entry of the output D matrix is $D_{ij} = C_{ij} + \sum_{k} (A_{i,k} \times SFA_{i,k/SV}) \times (B_{j,k}\times SFB_{j,k/SV})$ in index notation, where SV is the scale factor vector size (16 or 32). @@ -57,12 +57,12 @@ See [PTX documentation for narrow precision data types](https://docs.nvidia.com/ Block scaled MMAs use `mx` and `nv` types, which pair the float8_t, float6_t, and float4_t data types with one of the two scale factor data types and a predetermined scale factor vector size. `mx` types follow OCP specification (see [OCP Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)). The following types provided by CUTLASS can be used as inputs to collective builders to generate the block scaled kernels: **Blackwell Block Scaled Narrow Precision Data Types** -| Mx/Nv Data Type |Scale Factor Type | SF Vector Size | OCP Compliant | -|----------------------------|------------------|----------------|---------------| -| mx_float8_t\ |float_ue8m0_t |32 | Yes | -| mx_float6_t\ |float_ue8m0_t |32 | Yes | -| mx_float4_t |float_ue8m0_t |32 | Yes | -| nv_float4_t |float_ue4m3_t |16 | No | +| Mx/Nv Data Type |Scale Factor Type | SF Vector Size (Dense) | SF Vector Size (Sparse)| OCP Compliant | +|----------------------------|------------------|------------------------|------------------------|---------------| +| mx_float8_t\ |float_ue8m0_t |32 |64 | Yes | +| mx_float6_t\ |float_ue8m0_t |32 |64 | Yes | +| mx_float4_t |float_ue8m0_t |32 |64 | Yes | +| nv_float4_t |float_ue4m3_t |16 |32 | No | ## Layouts, Tensor Alignment Requirements to Target `tcgen05.mma` Instructions @@ -74,13 +74,18 @@ For legacy types (`tf32`, `f16`, `bf16`, `i8` and `u8`) alignment requirements f All four layouts (TN, NN, NT, TT) are supported for all legacy data types.
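Before the alignment and layout tables, a small standalone sketch (illustrative problem sizes only) makes the scale-factor bookkeeping above concrete; it mirrors the `ceil_div(K, group_size)` computation the mainloop hunks switch to:

```cpp
// Illustrative only: scale-factor extents implied by the formulas above.
// With scale-factor vector size SV, an MxK A matrix carries an M x ceil(K/SV) SFA
// tensor and an NxK B matrix an N x ceil(K/SV) SFB tensor.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  int M = 512, N = 1024, K = 4096;
  int sv = 32;                    // dense mx_* scale factor vector size (16 for nv_float4_t)
  int scale_k = ceil_div(K, sv);  // 4096 / 32 = 128
  std::printf("SFA: %d x %d, SFB: %d x %d\n", M, scale_k, N, scale_k);
  // For the sparse kernels the vector size doubles (64 for mx types, 32 for
  // nv_float4_t), since sparse GEMM compresses 2x along K, halving scale_k.
  return 0;
}
```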
**Table 1: Valid Data Type, Alignment, and Layout Combinations For MMAs with Legacy Types** -| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test | -|-------------------------------|------------|------------|----------------|-------------|-------------|-------------------------|-----------| -|1 | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | | -|2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)| -|3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)| -|4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)| -|5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu)| +| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test | +|-------------------------------|----------------|------------|------------|----------------|------------------|-------------|-------------------------|---------- | +|[1](#legacy_rows) | Dense | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | | +|[2](#legacy_rows) | Dense | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu) | +|[3](#legacy_rows) | Dense | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu)| +|[4](#legacy_rows) | Dense | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu) | +|[5](#legacy_rows) | Dense | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | [Similar to int8_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu) | +|[6](#legacy_rows) | Sparse | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 (N) / 8 (T) | 4 | tf32 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f32_f32_f32_f32_f32_tfmma.cu) | +|[7](#legacy_rows) | Sparse | half_t | half_t | TN, NN, NT, TT | 8 (N) / 16 (T) | 8 | f16 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f16_f16_f32_f16_f16_hmma.cu) | +|[8](#legacy_rows) | Sparse | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 (N) / 16 (T) | 8 | f16 | [Similar to half_t unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f16_f16_f32_f16_f16_hmma.cu) | +|[9](#legacy_rows) | Sparse | int8_t | int8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | i8 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_s8_s8_s32_s8_s8_imma.cu) | +|[10](#legacy_rows) | Sparse | uint8_t | uint8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | i8 | [Similar to int8_t unit 
tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_s8_s8_s32_s8_s8_imma.cu) | For narrow precision MMAs, not all A/B type and A/B layout combinations are supported by every `tcgen05.mma` instruction. Furthermore, tensor copy instructions for subbyte types impose additional alignment requirements while loading narrow-precision @@ -91,203 +96,298 @@ Below tables list valid layout, and alignment values for each A and B data type instructions supported by CUTLASS. **Table 2: Valid Data Type, Alignment, and Layout Combinations For Narrow Precision MMAs Without Block Scaling** -| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test | -|-------------------------------|----------|----------|----------------|-------------|-------------|-------------------------|-----------| -|[1](#nonbs_rows_1_2_3_6) | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[2](#nonbs_rows_1_2_3_6) | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[3](#nonbs_rows_1_2_3_6) | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[4](#nonbs_rows_4_7) | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | -|[5](#nonbs_rows_5_8) | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | -|[6](#nonbs_rows_1_2_3_6) | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | -|[7](#nonbs_rows_4_7) | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | -|[8](#nonbs_rows_5_8) | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | -|[9](#nonbs_rows_9) | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)| +| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test | +|-------------------------------|----------------|----------|----------|----------------|-------------------|-------------|-------------------------|-----------| +|[1](#nonbs_rows_1_2_3_6) | Dense | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[2](#nonbs_rows_1_2_3_6) | Dense | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[3](#nonbs_rows_1_2_3_6) | Dense | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[4](#nonbs_rows_4_7) | Dense | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | +|[5](#nonbs_rows_5_8) | Dense | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | +|[6](#nonbs_rows_1_2_3_6) | Dense | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu)
[NN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu)
[TT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu) | +|[7](#nonbs_rows_4_7) | Dense | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu) | +|[8](#nonbs_rows_5_8) | Dense | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu) | +|[9](#nonbs_rows_9) | Dense | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | [Unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu)| +|[10](#nonbs_rows_1_2_3_6) | Sparse | float4_t | float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu) | +|[11](#nonbs_rows_1_2_3_6) | Sparse | float4_t | float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu) | +|[12](#nonbs_rows_1_2_3_6) | Sparse | float6_t | float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu) | +|[13](#nonbs_rows_4_7) | Sparse | float4_t | float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu) | +|[14](#nonbs_rows_5_8) | Sparse | float8_t | float4_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu) | +|[15](#nonbs_rows_1_2_3_6) | Sparse | float6_t | float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu) | +|[16](#nonbs_rows_4_7) | Sparse | float6_t | float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu) | +|[17](#nonbs_rows_5_8) | Sparse | float8_t | float6_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu) | +|[18](#nonbs_rows_9) | Sparse | float8_t | float8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | f8f6f4 | [TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f8_f8_f32_f16_f16_qmma.cu) | **Table 3: Valid Data Type, Alignment, and Layout Combinations for Block Scaled Narrow Precision MMAs** -| | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test| -|-------------------------|-------------|-------------|----------------|-------------|-------------|-------------------------|------| -|[1](#bs_rows_1) | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)| -|[2](#bs_rows_2) | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit 
tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)| -|[3](#bs_rows_3) | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)| -|[4](#bs_rows_4_5_7_8_10) | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)| -|[5](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)| -|[6](#bs_rows_6_9_11) | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)| -|[7](#bs_rows_4_5_7_8_10) | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)| -|[8](#bs_rows_4_5_7_8_10) | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)| -|[9](#bs_rows_6_9_11) | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)| -|[10](#bs_rows_4_5_7_8_10)| mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)| -|[11](#bs_rows_6_9_11) | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)| +| | Dense / Sparse | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind |Unit Test| +|--------------------------|----------------|-------------|-------------|----------------|-------------------|-------------|-------------------------|---------| +|[1](#bs_rows_1) | Dense | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/nvf4_nvf4_bf16_bf16.cu)| +|[2](#bs_rows_2) | Dense | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_tn_layout.cu)| +|[3](#bs_rows_3) | Dense | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf4_void_f16_nt_layout.cu)| +|[4](#bs_rows_4_5_7_8_10) | Dense | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf6_f32_f16_nt_layout.cu)| +|[5](#bs_rows_4_5_7_8_10) | Dense | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf4_f16_f16_nt_layout.cu)| +|[6](#bs_rows_6_9_11) | Dense | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf4_mxf8_bf16_bf16_nt_layout.cu)| +|[7](#bs_rows_4_5_7_8_10) | Dense | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf4_f16_bf16_nt_layout.cu)| +|[8](#bs_rows_4_5_7_8_10) | Dense | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf6_void_bf16_nt_layout.cu)| +|[9](#bs_rows_6_9_11) | Dense | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf6_mxf8_void_f32_nt_layout.cu)| +|[10](#bs_rows_4_5_7_8_10) | Dense | mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_tn_layout.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf6_f16_f8_nt_layout.cu)| +|[11](#bs_rows_6_9_11) | Dense | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_tn_layout.cu.cu)
[NT unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/mxf8_mxf8_void_f8_nt_layout.cu)| +|[12](#bs_rows_1) | Sparse | nv_float4_t | nv_float4_t | TN | 32 (N) / 64 (T) | 32 | mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_nvf4_nvf4_f32_void_f16_o_tnn.cu) | +|[13](#bs_rows_2) | Sparse | mx_float4_t | mx_float4_t | TN | 32 (N) / 64 (T) | 32 | mxf4, mxf4nvf4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf4_f32_f16_f16_o_tnn.cu) | +|[14](#bs_rows_3) | Sparse | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf4_f32_f16_f16_q_tnt.cu) | +|[15](#bs_rows_4_5_7_8_10) | Sparse | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu) | +|[16](#bs_rows_4_5_7_8_10) | Sparse | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu) | +|[17](#bs_rows_6_9_11) | Sparse | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu) | +|[18](#bs_rows_4_5_7_8_10) | Sparse | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu) | +|[19](#bs_rows_4_5_7_8_10) | Sparse | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu) | +|[20](#bs_rows_6_9_11) | Sparse | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 (N) / 256 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu) | +|[21](#bs_rows_4_5_7_8_10) | Sparse | mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 128 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu) | +|[22](#bs_rows_6_9_11) | Sparse | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 (N) / 32 (T) | 16 | mxf8f6f4 |[TN unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf8_f32_f16_f16_q_tnn.cu) | ## MMA tile shapes supported The alignment restrictions also limit the options for Mma Tile Shapes. 
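For orientation, an `MmaTileShape` and dispatch policy from the tables below reach a kernel through the collective builders roughly as in the following hedged sketch, patterned on the SM100 GEMM examples. The element types, layouts, alignments, cluster shape, and type aliases here are illustrative choices for a dense float8_t x float8_t TN kernel, not prescriptions from this document:

```cpp
// Hedged sketch, not from this patch: wiring a tile shape and dispatch policy
// (dense 1SM row of Table 8) plus an epilogue policy (Table 14) into the builders.
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

using ElementA = cutlass::float_e4m3_t;                       // float8_t operand A
using ElementB = cutlass::float_e4m3_t;                       // float8_t operand B
using ElementC = cutlass::bfloat16_t;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;                    // T
using LayoutB = cutlass::layout::ColumnMajor;                 // N
using LayoutC = cutlass::layout::ColumnMajor;
using MmaTileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    MmaTileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, 8,
    ElementC, LayoutC, 8,
    cutlass::epilogue::TmaWarpSpecialized1Sm                  // epilogue policy, Table 14
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, 16,                                    // A alignment, Table 2 row 9
    ElementB, LayoutB, 16,                                    // B alignment, Table 2 row 9
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::KernelTmaWarpSpecialized1SmSm100           // mainloop dispatch policy, Table 8
  >::CollectiveOp;
```

The sparse and block-scaled rows substitute the corresponding `KernelSparseTmaWarpSpecialized*` or narrow-precision schedule tags (and, for block scaling, the matching mx/nv operand types) in the same builder positions.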
Tables below list the supported/valid `MmaTileShape`, Layout, and Dispatch Policy combinations for each row of [Table 1](#legacy_gemm_table), [Table 2](#non_bs_gemm_table), and [Table 3](#bs_gemm_table). -**Table 4: Valid Tile Shapes and Dispatch Policies for lagacy types (All rows of Table 1)** -| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | -|--------|------------------|----|----|----|----|------------------------------------| -| 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x128x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x192x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x256x(4*MMA-K)| Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +**Table 4: Valid Tile Shapes and Dispatch Policies for legacy types (All rows of Table 1)** +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|--------------------|----|----|----|----|------------------------------------------| +| Dense | 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x128x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x192x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x256x(4*MMA-K) | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Sparse | 1SM | 128x64x(2/4*MMA-K) | Y | Y | Y | Y | 
`KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x128x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x192x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 2SM | 256x64x(2/4*MMA-K) | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x128x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x192x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x(2/4*MMA-K)| Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | -**Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1,2,3,6 of Table 2)** +**Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1,2,3,6,10,11,12,and 15 of Table 2)** -| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | -|--------|----------------|----|----|----|----|------------------------------------| -| 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|----------------|----|----|----|----|------------------------------------------| +| Dense | 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` 
| +| Dense | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Sparse | 1SM | 128x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 2SM | 256x128x128 | N | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | -**Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 5,8 of Table 2)** +**Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 5,8,14,and 17 of Table 2)** -| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | -|--------|----------------|----|----|----|----|------------------------------------| -| 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|----------------|----|----|----|----|------------------------------------------| +| Dense | 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 2SM | 128x64x128 | Y | N 
| N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Sparse | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 2SM | 256x128x128 | Y | Y | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | -**Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4,7 of Table 2)** +**Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4,7,13,and 16 of Table 2)** -| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | -|--------|----------------|----|----|----|----|------------------------------------| -| 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|----------------|----|----|----|----|------------------------------------------| +| Dense | 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x64x128 | Y | Y | Y | Y | 
`KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Sparse | 1SM | 128x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 2SM | 256x128x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x128 | N | N | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | -**Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Row 9 of Table 2)** +**Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Row 9,18 of Table 2)** -| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | -|--------|----------------|----|----|----|----|------------------------------------| -| 1SM | 64x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | -| 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | -| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|----------------|----|----|----|----|------------------------------------------| +| Dense | 1SM | 64x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x128x128 | Y | 
Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | +| Dense | 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +| Sparse | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmSm100` | +| Sparse | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +| Sparse | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmSm100` | +**Table 9: Valid Tile Shapes for nv_float4_t x nv_float4_t (Row 1 and 12 of Table 3)** +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|----------------|----|----|----|----|----------------------------------------------| +| Dense | 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | +| Dense | 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | +| Dense | 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | +| Dense | 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | +| Dense | 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | +| Dense | 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | +| Sparse | 1SM | 128x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` | +| Sparse | 1SM | 128x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` | +| Sparse | 2SM | 256x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` | +| Sparse | 2SM | 256x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` | -**Table 9: Valid Tile Shapes for nv_float4_t x nv_float4_t (Row 1 of Table 3)** -| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | -|--------|---------------|----|----|----|----|----------------------------------------| -| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | -| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | -| 1SM | 
128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmNvf4Sm100` | -| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | -| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | -| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmNvf4Sm100` | +**Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 2 and 13 of Table 3)** +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|----------------|----|----|----|----|----------------------------------------------| +| Dense | 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | +| Dense | 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | +| Dense | 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | +| Dense | 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | +| Dense | 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | +| Dense | 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | +| Sparse | 1SM | 128x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` | +| Sparse | 1SM | 128x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized1SmNvf4Sm100` | +| Sparse | 2SM | 256x128x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` | +| Sparse | 2SM | 256x256x256 | Y | N | N | N | `KernelSparseTmaWarpSpecialized2SmNvf4Sm100` | -**Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 2 of Table 3)** -| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | -|--------|---------------|----|----|----|----|----------------------------------------| -| 1SM | 128x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | -| 1SM | 128x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | -| 1SM | 128x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized1SmMxf4Sm100` | -| 2SM | 256x128x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | -| 2SM | 256x192x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | -| 2SM | 256x256x256 | Y | N | N | N | `KernelTmaWarpSpecialized2SmMxf4Sm100` | +**Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 3 and 14 of Table 3)** +| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | +|----------------|--------|----------------|----|----|----|----|--------------------------------------------------| +| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` | +| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` | +| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` | +| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` | +| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` | +| Sparse | 2SM | 
+| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-**Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 3 of Table 3)**
-| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
-|--------|---------------|----|----|----|----|--------------------------------------------|
-| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+**Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10, 15, 16, 18, 19, and 21 of Table 3)**
+| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
+|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
+| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Dense | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Dense | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Dense | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Sparse | 2SM | 256x128x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Sparse | 2SM | 256x192x256 | Y | N | N | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-**Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10 of Table 3)**
-| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
-|--------|---------------|----|----|----|----|--------------------------------------------|
-| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-| 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-
-**Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11 of Table 3)**
-| 1/2 SM | Mma Tile Shape | TN| TT | NT | NN | Dispatch Policy |
-|--------|---------------|----|----|----|----|--------------------------------------------|
-| 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
-| 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-| 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
-| 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+**Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11, 17, 20, and 22 of Table 3)**
+| Dense / Sparse | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
+|----------------|--------|----------------|----|----|----|----|--------------------------------------------------|
+| Dense | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Dense | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Dense | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Dense | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Dense | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Dense | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Sparse | 1SM | 128x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Sparse | 1SM | 128x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Sparse | 1SM | 128x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100` |
+| Sparse | 2SM | 256x128x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Sparse | 2SM | 256x192x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
+| Sparse | 2SM | 256x256x256 | Y | Y | Y | Y | `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100` |
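To make the mapping from these tables to code concrete, the sketch below (an illustrative aside, not part of the diff) instantiates the mainloop collective for the Sparse / 2SM / 256x128x256 / TN row of Table 12, i.e. an mx_float4_t x mx_float6_t kernel driven by `KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100`. It follows the pattern of the SM100 block-scaled sparse unit tests added further down in this change; the exact element spellings, alignments, cluster shape, and the use of `StageCountAuto` are assumptions made for brevity.

```c++
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

// A operand: mxf4 (ue8m0 scale factors, e2m1 values), row-major ("T").
using ElementPairA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
// B operand: mxf6 (ue8m0 scale factors, e2m3 values), column-major ("N").
using ElementPairB = cutlass::mx_float6_t<cutlass::float_e2m3_t>;

using MmaTileShape = cute::Shape<cute::_256, cute::_128, cute::_256>;  // tile shape from the table row
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;        // assumed 2-SM cluster

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100,
    cutlass::arch::OpClassBlockScaledSparseTensorOp,   // block-scaled sparse tcgen05.mma.sp
    ElementPairA, cutlass::layout::RowMajor,    256,    // A type, layout, alignment in elements
    ElementPairB, cutlass::layout::ColumnMajor, 128,    // B type, layout, alignment in elements
    float,                                              // accumulator
    MmaTileShape,
    ClusterShape,
    cutlass::gemm::collective::StageCountAuto,          // the unit tests instead carve out epilogue smem
    cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100  // dispatch policy from the table
  >::CollectiveOp;
```

The epilogue collective is built analogously with a policy from Table 14 below, and the two collectives are then composed through `cutlass::gemm::kernel::GemmUniversal` and `cutlass::gemm::device::GemmUniversalAdapter`, as the unit tests in this change do.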
## Epilogue config supported
**Table 14: Epilogue Dispatch Policy**
-| 1/2 SM | Epilogue Dispatch Policy |
-|--------|------------------------------------------|
-| 1SM | cutlass::epilogue::TmaWarpSpecialized1Sm |
-| 1SM | cutlass::epilogue::NoSmemWarpSpecialized1Sm |
-| 2SM | cutlass::epilogue::TmaWarpSpecialized2Sm |
-| 2SM | cutlass::epilogue::NoSmemWarpSpecialized2Sm |
+| Dense / Sparse | Legacy / Narrow Precision | 1/2 SM | Epilogue Dispatch Policy |
+|----------------|-----------------------------|--------|----------------------------------------------------|
+| Dense | Legacy & Narrow Precision | 1SM | `cutlass::epilogue::TmaWarpSpecialized1Sm` |
+| Dense | Legacy & Narrow Precision | 1SM | `cutlass::epilogue::NoSmemWarpSpecialized1Sm` |
+| Dense | Legacy & Narrow Precision | 2SM | `cutlass::epilogue::TmaWarpSpecialized2Sm` |
+| Dense | Legacy & Narrow Precision | 2SM | `cutlass::epilogue::NoSmemWarpSpecialized2Sm` |
+| Sparse | Legacy | 1SM | `cutlass::epilogue::TmaWarpSpecialized1Sm` |
+| Sparse | Legacy | 1SM | `cutlass::epilogue::NoSmemWarpSpecialized1Sm` |
+| Sparse | Legacy | 2SM | `cutlass::epilogue::TmaWarpSpecialized2Sm` |
+| Sparse | Legacy | 2SM | `cutlass::epilogue::NoSmemWarpSpecialized2Sm` |
+| Sparse | Narrow Precision (nvf4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmNvf4` |
+| Sparse | Narrow Precision (nvf4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmNvf4` |
+| Sparse | Narrow Precision (mxf4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmMxf4` |
+| Sparse | Narrow Precision (mxf4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmMxf4` |
+| Sparse | Narrow Precision (mxf8f6f4) | 1SM | `cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4` |
+| Sparse | Narrow Precision (mxf8f6f4) | 2SM | `cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4` |
**Table 15: Epilogue PerSmTileShape_MNK**
| 1/2 SM | MMA tile Shape | PerSmTileShape_MNK |
@@ -314,14 +414,16 @@ MMA_TileShape_K is generally 4 * MMA-Instruction-K. It depends on the config
### Auto Kernel Dispatch Policies
In addition to direct dispatch policies listed above, the user can also use auto policies for both non-block scaled narrow-precision
-GEMMs, and block scaled narrow-precision GEMMs.
+GEMMs (both sparse and dense), and block scaled narrow-precision GEMMs (only dense).
CUTLASS will do its best to find the most efficient kernel for given parameters; however, the preferred method for building these kernels is to use direct kernel dispatch policies shown in the above tables.
-* `cutlass::gemm::collective::KernelScheduleAuto`: For a given Mma Tile Size, data type and layout combinations choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) and 1/2 SM `tcgen05.mma`.
+* `cutlass::gemm::collective::KernelScheduleAuto`: For a given Mma Tile Size, data type and layout combinations choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) and 1/2 SM `tcgen05.mma(.sp)`.
* `KernelTmaWarpSpecialized1SmBlockScaledSm100`: Use 1 SM `tcgen05.mma` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
* `KernelTmaWarpSpecialized2SmBlockScaledSm100`: Use 2 SM `tcgen05.mma` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
+* `KernelSparseTmaWarpSpecialized1SmBlockScaledSm100`: Use 1 SM `tcgen05.mma.sp` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
+* `KernelSparseTmaWarpSpecialized2SmBlockScaledSm100`: Use 2 SM `tcgen05.mma.sp` instruction and choose instr kind (mxf8f6f4, mxf4, nvf4mxf4) automatically.
Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueScheduleAuto`.
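As a small illustration (again not part of the diff), only the schedule arguments handed to the collective builders change relative to the explicit-policy sketch shown after Table 13; the policy names below are the ones listed above.

```c++
// Semi-automatic: fix a 1-SM sparse block-scaled kernel and let the builder
// pick the instruction kind (mxf8f6f4, mxf4, nvf4mxf4).
using KernelSchedule   = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmBlockScaledSm100;
// For dense block-scaled kernels, fully automatic selection is also available:
//   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;

// The epilogue schedule can likewise be left to the builder.
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
```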
@@ -330,16 +432,23 @@ Similarly for epilogues, we can use `cutlass::epilogue::collective::EpilogueSche
For non-blockscaled dense GEMM refer to [quick start page](quickstart.md#instantiating-a-blackwell-sm100-gemm-kernel). An example dense GEMM can be found:
1. [Blackwell FP16 GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/70_blackwell_gemm/).
+An example sparse GEMM can be found:
+1. [Blackwell FP16 Sparse GEMM example](https://github.com/NVIDIA/cutlass/tree/main/examples/83_blackwell_sparse_gemm/).
+
Narrow precision and block scaled narrow precision kernels can be built using CUTLASS 3.x collective builder interface (as described in [CUTLASS 3.0 GEMM API](gemm_api_3x.md#cutlass-30-gemm-api)). However, special attention needs to be given to A and B matrix layouts, alignment requirements, and dispatch policies to obtain a functionally correct and performant kernel which are listed above.
-Several examples of block scaled kernels can be found in [examples/72_blackwell_narrow_precision_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/) directory:
+Several examples of block scaled dense GEMM kernels can be found in [examples/72_blackwell_narrow_precision_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/) directory:
1. [NVF4 Gemm with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
2. [NVF4 Gemm with block scaling and NVF4 output matrix](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
3. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
+Several examples of block scaled sparse GEMM kernels can be found in [examples/84_blackwell_narrow_precision_sparse_gemm](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm) directory:
+1. [NVF4 Gemm with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
+2. [Mixed precision Nvf4 x Mxf8 GEMM with block scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
+
Collective builder interface expects the same arguments as any other CUTLASS 3.x kernels as described [here](gemm_api_3x.md#collective-builder-for-collectivemmas) with a small difference for Collective MMA builder interface.
As in all Blackwell kernels, the `TileShape_MNK` argument expects the `MmaTileShape_MNK` which is the tile shape needed
diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp
index 676b111e..5ee68881 100644
--- a/test/unit/common/filter_architecture.cpp
+++ b/test/unit/common/filter_architecture.cpp
@@ -28,7 +28,6 @@
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
-
#include
#include "cutlass_unit_test.h"
@@ -59,7 +58,10 @@ std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &deviceProperties
  int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor;
  if (deviceMajorMinor) {
-    int32_t clock_MHz = deviceProperties.clockRate / 1000;
+    int32_t clock_MHz;
+    int32_t clock_KHz;
+    cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
+    clock_MHz = clock_KHz / 1000;
    out << "GPU(compute_" << deviceMajorMinor << ", "
        << deviceProperties.multiProcessorCount << " SMs @ " << clock_MHz << " MHz)";
diff --git a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/CMakeLists.txt
index 735b3760..5575fe98 100644
--- a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/CMakeLists.txt
+++ b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/CMakeLists.txt
@@ -29,22 +29,25 @@ if (CUTLASS_NVCC_ARCHS MATCHES 100a)
add_custom_target(
-  cutlass_test_unit_gemm_device_sm100_bssp
+  cutlass_test_unit_gemm_device_sm100_blockscaled_sparse
  DEPENDS
-  cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o
-  cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o
-  cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o
-  cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q
-  cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q
-  cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q
-  cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q
-  cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
-
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o - cutlass_test_unit_gemm_device_sm100_bssp_streamk + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o BATCH_SOURCES ON BATCH_SIZE 1 @@ -57,7 +60,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o BATCH_SOURCES ON BATCH_SIZE 1 @@ -70,7 +73,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o BATCH_SOURCES ON BATCH_SIZE 1 @@ -83,7 +86,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q BATCH_SOURCES ON BATCH_SIZE 1 @@ -96,7 +99,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q BATCH_SOURCES ON BATCH_SIZE 1 @@ -109,7 +112,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q BATCH_SOURCES ON BATCH_SIZE 1 @@ -127,7 +130,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o BATCH_SOURCES ON BATCH_SIZE 1 @@ -140,7 +143,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q BATCH_SOURCES ON BATCH_SIZE 1 @@ -148,10 +151,32 @@ 
cutlass_test_unit_gemm_device_add_executable_split_file( sm100_bssp_gemm_mxf8_mxf4_f32_f16_mxf8_q_tnt.cu sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu sm100_bssp_gemm_mxf8_mxf4_f32_f32_f32_q_tnt.cu + + sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q + + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu + sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q + + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu + sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q BATCH_SOURCES ON BATCH_SIZE 1 @@ -162,7 +187,16 @@ cutlass_test_unit_gemm_device_add_executable_split_file( ) cutlass_test_unit_gemm_device_add_executable_split_file( - cutlass_test_unit_gemm_device_sm100_bssp_streamk + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q + + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk BATCH_SOURCES ON BATCH_SIZE 1 diff --git a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu new file mode 100644 index 00000000..c9660746 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu @@ -0,0 +1,1102 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt_vs64in +// 2. 128x192_tnt_vs64in +// 3. 128x256_tnt_vs64in +// 4. 256x128_tnt_vs64in +// 5. 256x192_tnt_vs64in +// 6. 256x256_tnt_vs64in + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 5. +TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 6. +TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 5. +TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 6. +TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe2m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif diff --git a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu new file mode 100644 index 00000000..2d1f7fe2 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu @@ -0,0 +1,1102 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt_vs64in +// 2. 128x192_tnt_vs64in +// 3. 128x256_tnt_vs64in +// 4. 256x128_tnt_vs64in +// 5. 256x192_tnt_vs64in +// 6. 256x256_tnt_vs64in + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 5.
+TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 6.
+TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 1.
+namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementPairA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementPairB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementC = void;
+  using ElementD = float;
+
+  constexpr int kAlignmentA = 256;
+  constexpr int kAlignmentB = 16;
+  constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ?
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float4_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 5.
+TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 6.
+TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m1_ue8m0xe4m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+#endif
diff --git a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu
new file mode 100644
index 00000000..3a2b8fb0
--- /dev/null
+++ b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu
@@ -0,0 +1,1102 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt_vs64in +// 2. 128x192_tnt_vs64in +// 3. 128x256_tnt_vs64in +// 4. 256x128_tnt_vs64in +// 5. 256x192_tnt_vs64in +// 6. 256x256_tnt_vs64in + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 5.
+TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 6.
+TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 1.
+namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementPairA = cutlass::mx_float6_t<cutlass::float_e2m3_t>;
+  using ElementPairB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementC = void;
+  using ElementD = float;
+
+  constexpr int kAlignmentA = 256;
+  constexpr int kAlignmentB = 128;
+  constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ?
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float4_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 5. +TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 6. +TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe2m1_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif diff --git a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu new file mode 100644 index 00000000..3645ef36 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu @@ -0,0 +1,1102 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#include <iostream>
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cutlass/numeric_types.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/thread/activation.h"
+#include "../../../common/cutlass_unit_test.h"
+#include "../gemm_testbed_3x.hpp"
+
+using namespace cute;
+
+// * Test list
+// 1. 128x128_tnt_vs64in
+// 2. 128x192_tnt_vs64in
+// 3. 128x256_tnt_vs64in
+// 4. 256x128_tnt_vs64in
+// 5. 256x192_tnt_vs64in
+// 6. 256x256_tnt_vs64in
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+// 1.
+namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementPairA = cutlass::mx_float6_t<cutlass::float_e2m3_t>;
+  using ElementPairB = cutlass::mx_float6_t<cutlass::float_e3m2_t>;
+  using ElementC = float;
+  using ElementD = float;
+
+  constexpr int kAlignmentA = 256;
+  constexpr int kAlignmentB = 128;
+  constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ?
+    kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  using ProblemShape = Shape<int, int, int, int>;
+  using ClusterShape = cute::Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_128, _128, _256>;
+  using ArchTag = cutlass::arch::Sm100;
+  using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp;
+  using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
+  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4;
+  using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100;
+  using ElementAccumulator = float;
+  using ElementEpilogueCompute = float;
+  using ElementBias = float;
+  using TileScheduler = cutlass::gemm::PersistentScheduler;
+
+  using CollectiveEpilogue =
+    typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OpClassTag,
+      MmaTileShape,
+      ClusterShape,
+      EpilogueTile,
+      ElementAccumulator,
+      ElementEpilogueCompute,
+      ElementC, LayoutC, kAlignmentC,
+      ElementD, LayoutD, kAlignmentD,
+      EpilogueScheduleType,
+      cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct<
+        cutlass::epilogue::thread::Clamp,
+        ElementD,
+        ElementEpilogueCompute,
+        ElementBias,
+        ElementC,
+        ElementEpilogueCompute>
+    >::CollectiveOp;
+
+  using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
+
+  using CollectiveMainloop =
+    typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OpClassTag,
+      ElementPairA, LayoutA, kAlignmentA,
+      ElementPairB, LayoutB, kAlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      StageCount,
+      KernelScheduleType
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+    cute::Shape<int, int, int, int>,
+    CollectiveMainloop,
+    CollectiveEpilogue,
+    void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+}
+
+// 2.
+namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementPairA = cutlass::mx_float6_t<cutlass::float_e2m3_t>;
+  using ElementPairB = cutlass::mx_float6_t<cutlass::float_e3m2_t>;
+  using ElementC = float;
+  using ElementD = float;
+
+  constexpr int kAlignmentA = 256;
+  constexpr int kAlignmentB = 128;
+  constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ?
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 5. +TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 6. +TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 5. +TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 6. +TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe3m2_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif diff --git a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu new file mode 100644 index 00000000..d3850ca0 --- /dev/null +++ b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu @@ -0,0 +1,1102 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#include <iostream>
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cutlass/numeric_types.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/thread/activation.h"
+#include "../../../common/cutlass_unit_test.h"
+#include "../gemm_testbed_3x.hpp"
+
+using namespace cute;
+
+// * Test list
+// 1. 128x128_tnt_vs64in
+// 2. 128x192_tnt_vs64in
+// 3. 128x256_tnt_vs64in
+// 4. 256x128_tnt_vs64in
+// 5. 256x192_tnt_vs64in
+// 6. 256x256_tnt_vs64in
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+// 1.
+namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementPairA = cutlass::mx_float6_t<cutlass::float_e2m3_t>;
+  using ElementPairB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementC = float;
+  using ElementD = float;
+
+  constexpr int kAlignmentA = 256;
+  constexpr int kAlignmentB = 16;
+  constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ?
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 5. +TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 6. +TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float6_t; + using ElementPairB = cutlass::mx_float8_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 5.
+TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 6.
+TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe2m3_ue8m0xe4m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+#endif
diff --git a/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu
new file mode 100644
index 00000000..e3040f86
--- /dev/null
+++ b/test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm/sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu
@@ -0,0 +1,1102 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt_vs64in +// 2. 128x192_tnt_vs64in +// 3. 128x256_tnt_vs64in +// 4. 256x128_tnt_vs64in +// 5. 256x192_tnt_vs64in +// 6. 256x256_tnt_vs64in + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = float; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x128x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 5. +TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x192x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 6. +TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_f16_f16_256x256x256_0_vs64_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 1. +namespace cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 5. +namespace cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _192, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 6. +namespace cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementPairA = cutlass::mx_float8_t; + using ElementPairB = cutlass::mx_float6_t; + using ElementC = void; + using ElementD = float; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassTag = cutlass::arch::OpClassBlockScaledSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = float; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType, + cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Clamp, + ElementD, + ElementEpilogueCompute, + ElementBias, + ElementC, + ElementEpilogueCompute> + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassTag, + ElementPairA, LayoutA, kAlignmentA, + ElementPairB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x128x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x192x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_bssptensorop_s128x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_128x256x256_0_vs64_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x128x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x128x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 5.
+TEST(cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x192x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x192x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 6.
+TEST(cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_bssptensorop_s256x256x64bsspgemm_ue8m0xe4m3_ue8m0xe2m3_f32_void_f32_256x256x256_0_vs64_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+#endif
diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/CMakeLists.txt
index 334cdce4..f61bf8ad 100644
--- a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/CMakeLists.txt
+++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/CMakeLists.txt
@@ -26,18 +26,19 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
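Before the build-system changes below, here is a minimal sketch of the configuration pattern that each numbered test namespace above instantiates, with the template arguments spelled out. The mx_float8_t/mx_float6_t element parameterizations, the cluster shape, and the StageCountAutoCarveoutEpi argument are assumptions inferred from the surrounding type and namespace names (ue8m0xe4m3, ue8m0xe2m3, 2sm), not taken verbatim from this patch.

// Sketch only (not part of this patch): the 2-SM, 256x128x256 block-scaled sparse
// configuration, mirroring namespace ...256x128x256_0_vs64_tnt_align32_q_2sm above.
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"

using namespace cute;

using ElementPairA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;  // assumed from "ue8m0xe4m3"
using ElementPairB = cutlass::mx_float6_t<cutlass::float_e2m3_t>;  // assumed from "ue8m0xe2m3"
using ElementC = void;                                             // no source tensor C
using ElementD = float;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = float;

using MmaTileShape = Shape<_256, _128, _256>;
using ClusterShape = Shape<_2, _1, _1>;                            // assumed for a 2-SM kernel

constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void>
                                ? kAlignmentD
                                : 128 / cutlass::sizeof_bits<ElementC>::value;

// Epilogue: per-row scaled linear combination with per-row bias and clamp activation.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledSparseTensorOp,
    MmaTileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, cutlass::layout::RowMajor, kAlignmentC,
    ElementD, cutlass::layout::RowMajor, kAlignmentD,
    cutlass::epilogue::TmaWarpSpecialized2SmMxf8f6f4,
    cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct<
        cutlass::epilogue::thread::Clamp, ElementD, ElementEpilogueCompute,
        ElementBias, ElementC, ElementEpilogueCompute>
  >::CollectiveOp;

// Mainloop: sparse block-scaled MMA; stage count carved out around the epilogue.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledSparseTensorOp,
    ElementPairA, cutlass::layout::RowMajor, kAlignmentA,
    ElementPairB, cutlass::layout::ColumnMajor, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>,  // assumed argument
    cutlass::gemm::KernelSparseTmaWarpSpecialized2SmMxf8f6f4Sm100
  >::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,   // (M, N, K, L)
    CollectiveMainloop,
    CollectiveEpilogue,
    void>;                       // default tile scheduler

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

The matching functional test then presumably exercises this type through the shared testbed, as in the TEST bodies above: test::gemm::device::TestSmall<Gemm>(1, 0, CheckEquality::RELATIVE, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {256, 2560}).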
+add_subdirectory(narrow_precision)
+
 if (CUTLASS_NVCC_ARCHS MATCHES 100a)
 
 add_custom_target(
-  cutlass_test_unit_gemm_device_sm100_sp
+  cutlass_test_unit_gemm_device_sm100_sparse
 
   DEPENDS
-  cutlass_test_unit_gemm_device_sm100_sp_general
-  cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
-  cutlass_test_unit_gemm_device_sm100_sp_streamk
+  cutlass_test_unit_gemm_device_sm100_sparse_general
+  cutlass_test_unit_gemm_device_sm100_sparse_streamk
 )
 
 cutlass_test_unit_gemm_device_add_executable_split_file(
-  cutlass_test_unit_gemm_device_sm100_sp_general
+  cutlass_test_unit_gemm_device_sm100_sparse_general
 
   # No batching of source to control compiler memory usage
   BATCH_SOURCES ON
@@ -52,23 +53,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
 )
 
 cutlass_test_unit_gemm_device_add_executable_split_file(
-  cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
-
-  # No batching of source to control compiler memory usage
-  BATCH_SOURCES ON
-  BATCH_SIZE 1
-
-  sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu
-  sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu
-  sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu
-
-  sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu
-  sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu
-  sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu
-)
-
-cutlass_test_unit_gemm_device_add_executable_split_file(
-  cutlass_test_unit_gemm_device_sm100_sp_streamk
+  cutlass_test_unit_gemm_device_sm100_sparse_streamk
 
   # No batching of source to control compiler memory usage
   BATCH_SOURCES ON
diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/CMakeLists.txt b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/CMakeLists.txt
new file mode 100644
index 00000000..dc1549ec
--- /dev/null
+++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/CMakeLists.txt
@@ -0,0 +1,77 @@
+# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
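The new narrow_precision CMakeLists that follows groups the test sources by input-type pairing (f4xf4, f6xf6, f4f6xf4f6, f4f8xf4f8, f6f8xf6f8). As a quick reference for what the f4/f6/f8 suffixes in those file names denote, a small stand-alone snippet (hypothetical, not part of this patch) checking the bit widths of the CUTLASS element types used by these tests:

// Hypothetical quick reference, not part of this patch: bit widths of the
// narrow-precision element types referenced by the f4/f6/f8 file names.
#include "cutlass/numeric_types.h"

static_assert(cutlass::sizeof_bits<cutlass::float_e2m1_t>::value == 4, "f4: e2m1");
static_assert(cutlass::sizeof_bits<cutlass::float_e2m3_t>::value == 6, "f6 variant: e2m3");
static_assert(cutlass::sizeof_bits<cutlass::float_e3m2_t>::value == 6, "f6 variant: e3m2");
static_assert(cutlass::sizeof_bits<cutlass::float_e4m3_t>::value == 8, "f8: e4m3");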
+ +if (CUTLASS_NVCC_ARCHS MATCHES 100a) + +add_custom_target( + cutlass_test_unit_gemm_device_sm100_sparse_narrow_precision + DEPENDS + cutlass_test_unit_gemm_device_sm100_sparse_f4xf4 + cutlass_test_unit_gemm_device_sm100_sparse_f6xf6 + cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6 + cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8 + cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8 +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_sparse_f4xf4 + sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu + sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu + sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_sparse_f6xf6 + + sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu + sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu + sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6 + + sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu + sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8 + + sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu + sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8 + + sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu + sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu +) + +endif() diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu similarity index 99% rename from test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu rename to test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu index 7819c33f..e46934c8 100644 --- a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu @@ -40,8 +40,8 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/thread/activation.h" -#include "../../../common/cutlass_unit_test.h" -#include "../gemm_testbed_3x.hpp" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" using namespace cute; diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu similarity index 99% rename from test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu rename to test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu index b3a61011..dbce7f5d 100644 --- a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu @@ -40,8 +40,8 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/thread/activation.h" -#include "../../../common/cutlass_unit_test.h" -#include "../gemm_testbed_3x.hpp" +#include "../../../../common/cutlass_unit_test.h" +#include 
"../../gemm_testbed_3x.hpp" using namespace cute; diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu similarity index 99% rename from test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu rename to test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu index 91ac9c12..94f76ca0 100644 --- a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu @@ -40,8 +40,8 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/thread/activation.h" -#include "../../../common/cutlass_unit_test.h" -#include "../gemm_testbed_3x.hpp" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" using namespace cute; diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu new file mode 100644 index 00000000..1896e27d --- /dev/null +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt +// 2. 128x256_tnt +// 3. 256x128_tnt +// 4. 256x256_tnt + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. +namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. 
+namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. +TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu new file mode 100644 index 00000000..4960e6a7 --- /dev/null +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt +// 2. 128x256_tnt +// 3. 256x128_tnt +// 4. 256x256_tnt + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
+
+  using ProblemShape = Shape<int, int, int, int>;
+  using ClusterShape = cute::Shape<int32_t, int32_t, _1>;
+  using MmaTileShape = Shape<_256, _256, _256>;
+  using ArchTag = cutlass::arch::Sm100;
+  using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
+  using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
+  using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
+  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
+  using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
+  using ElementAccumulator = float;
+  using ElementEpilogueCompute = float;
+  using ElementBias = cutlass::half_t;
+  using TileScheduler = cutlass::gemm::PersistentScheduler;
+
+  using CollectiveEpilogue =
+    typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OpClassEpilogue,
+      MmaTileShape,
+      ClusterShape,
+      EpilogueTile,
+      ElementAccumulator,
+      ElementEpilogueCompute,
+      ElementC, LayoutC, kAlignmentC,
+      ElementD, LayoutD, kAlignmentD,
+      EpilogueScheduleType
+    >::CollectiveOp;
+
+  using StageCount = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>;
+
+  using CollectiveMainloop =
+    typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OpClassMainLoop,
+      ElementA, LayoutA, kAlignmentA,
+      ElementB, LayoutB, kAlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      StageCount,
+      KernelScheduleType
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cute::Shape<int, int, int, int>,
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+}
+
+// 1.
+TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 2.
+TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 3.
+TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 4.
+TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 1,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+
+// 1.
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementA = cutlass::float_e2m1_t;
+  using ElementB = cutlass::float_e4m3_t;
+  using ElementC = void;
+  using ElementD = cutlass::half_t;
+
+  constexpr int kAlignmentA = 256;
+  constexpr int kAlignmentB = 16;
+  constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  using ProblemShape = Shape<int, int, int, int>;
+  using ClusterShape = cute::Shape<int32_t, int32_t, _1>;
+  using MmaTileShape = Shape<_128, _128, _256>;
+  using ArchTag = cutlass::arch::Sm100;
+  using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
+  using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
+  using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
+  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
+  using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
+  using ElementAccumulator = float;
+  using ElementEpilogueCompute = float;
+  using ElementBias = cutlass::half_t;
+  using TileScheduler = cutlass::gemm::PersistentScheduler;
+
+  using CollectiveEpilogue =
+    typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OpClassEpilogue,
+      MmaTileShape,
+      ClusterShape,
+      EpilogueTile,
+      ElementAccumulator,
+      ElementEpilogueCompute,
+      ElementC, LayoutC, kAlignmentC,
+      ElementD, LayoutD, kAlignmentD,
+      EpilogueScheduleType
+    >::CollectiveOp;
+
+  using StageCount = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>;
+
+  using CollectiveMainloop =
+    typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OpClassMainLoop,
+      ElementA, LayoutA, kAlignmentA,
+      ElementB, LayoutB, kAlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      StageCount,
+      KernelScheduleType
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cute::Shape<int, int, int, int>,
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+}
+
+// 2.
+namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementA = cutlass::float_e2m1_t;
+  using ElementB = cutlass::float_e4m3_t;
+  using ElementC = void;
+  using ElementD = cutlass::half_t;
+
+  constexpr int kAlignmentA = 256;
+  constexpr int kAlignmentB = 16;
+  constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ?
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu new file mode 100644 index 00000000..9643370a --- /dev/null +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt +// 2. 128x256_tnt +// 3. 256x128_tnt +// 4. 256x256_tnt + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. +TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu similarity index 99% rename from test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu rename to test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu index e89f844c..94b52d60 100644 --- a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu @@ -40,8 +40,8 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/thread/activation.h" -#include "../../../common/cutlass_unit_test.h" -#include "../gemm_testbed_3x.hpp" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" using namespace cute; diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu similarity index 99% rename from test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu rename to test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu index 1f3cc952..c8ab01da 100644 --- a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu @@ -40,8 +40,8 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/thread/activation.h" -#include "../../../common/cutlass_unit_test.h" -#include "../gemm_testbed_3x.hpp" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" using namespace cute; diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu similarity index 99% rename from test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu rename to test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu index 602b07ca..f97911be 100644 --- a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu @@ -40,8 +40,8 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/thread/activation.h" -#include "../../../common/cutlass_unit_test.h" -#include "../gemm_testbed_3x.hpp" +#include 
"../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" using namespace cute; diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu new file mode 100644 index 00000000..ad81c5f8 --- /dev/null +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt +// 2. 128x256_tnt +// 3. 256x128_tnt +// 4. 256x256_tnt + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. +TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e3m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 256; + constexpr int kAlignmentB = 16; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu new file mode 100644 index 00000000..4b0ab7f8 --- /dev/null +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt +// 2. 128x256_tnt +// 3. 256x128_tnt +// 4. 256x256_tnt + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. +TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu new file mode 100644 index 00000000..adf4188e --- /dev/null +++ b/test/unit/gemm/device/sm100_sparse_tensorop_gemm/narrow_precision/sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +// * Test list +// 1. 128x128_tnt +// 2. 128x256_tnt +// 3. 256x128_tnt +// 4. 256x256_tnt + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. +TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 1, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + + +// 1. 
+namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 2. +namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_128, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 3. +namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _128, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 4. +namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e3m2_t; + using ElementC = void; + using ElementD = cutlass::half_t; + + constexpr int kAlignmentA = 32; + constexpr int kAlignmentB = 128; + constexpr int kAlignmentD = 128 / cutlass::sizeof_bits::value; + constexpr int kAlignmentC = cute::is_same_v ? 
kAlignmentD : 128 / cutlass::sizeof_bits::value; + + using ProblemShape = Shape; + using ClusterShape = cute::Shape; + using MmaTileShape = Shape<_256, _256, _256>; + using ArchTag = cutlass::arch::Sm100; + using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp; + using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100; + using ElementAccumulator = float; + using ElementEpilogueCompute = float; + using ElementBias = cutlass::half_t; + using TileScheduler = cutlass::gemm::PersistentScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OpClassEpilogue, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementEpilogueCompute, + ElementC, LayoutC, kAlignmentC, + ElementD, LayoutD, kAlignmentD, + EpilogueScheduleType + >::CollectiveOp; + + using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OpClassMainLoop, + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + StageCount, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +} + +// 1. +TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 2. +TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 3. +TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) { + namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm; + EXPECT_TRUE(test::gemm::device::TestSmall( + 1, 0, + test::gemm::device::CheckEquality::RELATIVE, + test::gemm::device::ScalarLoc::ON_DEVICE, + test::gemm::device::VectorScale::ENABLED, + {256, 2560})); +} + +// 4. 
+// 1.
+TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 2.
+TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 3.
+TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+// 4.
+TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
+  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
+  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
+    1, 0,
+    test::gemm::device::CheckEquality::RELATIVE,
+    test::gemm::device::ScalarLoc::ON_DEVICE,
+    test::gemm::device::VectorScale::ENABLED,
+    {256, 2560}));
+}
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu
index 0573bd9d..4b27c64f 100644
--- a/tools/profiler/src/options.cu
+++ b/tools/profiler/src/options.cu
@@ -31,7 +31,8 @@
 /* \file
    \brief Command line options for performance test program
 */
-
+#include 
+#include 
 #include 
 #include 
 #include 
@@ -165,9 +166,11 @@ void Options::Device::print_usage(std::ostream &out) const {
       break;
     }
     else {
+      int32_t clock_KHz;
+      cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
       out << " [" << idx << "] - " << prop.name << " - SM " << prop.major << "." << prop.minor << ", "
-        << prop.multiProcessorCount << " SMs @ " << (prop.clockRate / 1000.0) << " MHz, "
+        << prop.multiProcessorCount << " SMs @ " << (clock_KHz / 1000.0) << " MHz, "
         << "L2 cache: " << (prop.l2CacheSize >> 20) << " MB, Global Memory: " << (prop.totalGlobalMem >> 30) << " GB"
         << std::endl;
     }
@@ -216,9 +219,11 @@ void Options::Device::print_options(std::ostream &out, int indent) const {
   for (int device : devices) {
     out << device << ',';
   }
+  int32_t clock_KHz;
+  cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
   out
     << "\n"
-    << indent_str(indent) << "clock: " << int(double(properties[0].clockRate) / 1000.0) << "\n"
+    << indent_str(indent) << "clock: " << int(double(clock_KHz) / 1000.0) << "\n"
     << indent_str(indent) << "compute-capability: " << compute_capability(0) << "\n";
 }
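The profiler change above stops reading cudaDeviceProp::clockRate, which recent CUDA toolkits deprecate, and queries the clock through the attribute API instead. A standalone illustration of the same query (device 0, reported in kHz):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      int clock_khz = 0;
      // cudaDevAttrClockRate reports the peak SM clock of the given device in kHz.
      if (cudaDeviceGetAttribute(&clock_khz, cudaDevAttrClockRate, /*device=*/0) == cudaSuccess) {
        std::printf("SM clock: %.1f MHz\n", clock_khz / 1000.0);
      }
      return 0;
    }

The reference-utility changes that follow thread a caller-supplied stream through the comparison and reduction kernels and synchronize that stream before the blocking cudaMemcpy that reads the result back; without the synchronization, a kernel launched on a non-blocking stream could still be running when the copy samples device memory. A small sketch of the same pattern, with hypothetical names:

    #include <cuda_runtime.h>

    __global__ void set_flag(int *flag) { *flag = 1; }

    // Hypothetical: launch on the caller's stream, then synchronize it before a
    // blocking cudaMemcpy reads the result, mirroring the tensor_compare/tensor_reduce fix.
    int read_flag_on_stream(cudaStream_t stream) {
      int *d_flag = nullptr;
      cudaMalloc(&d_flag, sizeof(int));
      cudaMemsetAsync(d_flag, 0, sizeof(int), stream);

      set_flag<<<1, 1, 0, stream>>>(d_flag);

      // Ensure the kernel has finished before the synchronous copy below.
      cudaStreamSynchronize(stream);

      int h_flag = 0;
      cudaMemcpy(&h_flag, d_flag, sizeof(int), cudaMemcpyDeviceToHost);
      cudaFree(d_flag);
      return h_flag;
    }
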
diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h
index 4347bcac..1999730f 100644
--- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h
+++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h
@@ -109,7 +109,8 @@ bool BlockCompareEqual(
   Element const *ptr_B,
   size_t capacity,
   int grid_size = 0,
-  int block_size = 0) {
+  int block_size = 0,
+  cudaStream_t stream = nullptr) {
 
   int equal_flag = 1;
   int *device_equal_flag = nullptr;
@@ -146,7 +147,9 @@ bool BlockCompareEqual(
   dim3 grid(grid_size, 1, 1);
   dim3 block(block_size, 1, 1);
 
-  kernel::BlockCompareEqual<Element><<< grid, block >>>(device_equal_flag, ptr_A, ptr_B, capacity);
+  kernel::BlockCompareEqual<Element><<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity);
+
+  cudaStreamSynchronize(stream);
 
   if (cudaMemcpy(
     &equal_flag,
@@ -175,7 +178,8 @@ bool BlockCompareRelativelyEqual(
   Element epsilon,
   Element nonzero_floor,
   int grid_size = 0,
-  int block_size = 0) {
+  int block_size = 0,
+  cudaStream_t stream = nullptr) {
 
   int equal_flag = 1;
   int *device_equal_flag = nullptr;
@@ -212,7 +216,7 @@ bool BlockCompareRelativelyEqual(
   dim3 grid(grid_size, 1, 1);
   dim3 block(block_size, 1, 1);
 
-  kernel::BlockCompareRelativelyEqual<Element><<< grid, block >>>(
+  kernel::BlockCompareRelativelyEqual<Element><<< grid, block, 0, stream >>>(
     device_equal_flag,
     ptr_A,
     ptr_B,
@@ -221,6 +225,8 @@ bool BlockCompareRelativelyEqual(
     epsilon,
     nonzero_floor
   );
 
+  cudaStreamSynchronize(stream);
+
   if (cudaMemcpy(
     &equal_flag,
     device_equal_flag,
diff --git a/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/tools/util/include/cutlass/util/reference/device/tensor_reduce.h
index c210d533..3e6d7b30 100644
--- a/tools/util/include/cutlass/util/reference/device/tensor_reduce.h
+++ b/tools/util/include/cutlass/util/reference/device/tensor_reduce.h
@@ -232,6 +232,8 @@ ComputeType TensorTransformReduce(
     workspace, identity, workspace_size, reduce
   );
 
+  cudaStreamSynchronize(stream);
+
   if (copy_out) {
     cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
     if (result != cudaSuccess) {
@@ -285,6 +287,8 @@ ComputeType TensorTransformReduce(
     workspace, identity, workspace_size, reduce
   );
 
+  cudaStreamSynchronize(stream);
+
   if (copy_out) {
     cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
     if (result != cudaSuccess) {