v3.9 update (#2213)

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-04-02 23:10:16 -07:00
committed by GitHub
parent 6f4921858b
commit 79fc51f4b8
72 changed files with 19875 additions and 459 deletions

View File

@ -28,7 +28,6 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cuda_runtime_api.h>
#include "cutlass_unit_test.h"
@ -59,7 +58,10 @@ std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &deviceProperti
int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor;
if (deviceMajorMinor) {
int32_t clock_MHz = deviceProperties.clockRate / 1000;
int32_t clock_MHz;
int32_t clock_KHz;
cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
clock_MHz = clock_KHz / 1000;
out << "GPU(compute_"
<< deviceMajorMinor << ", "
<< deviceProperties.multiProcessorCount << " SMs @ " << clock_MHz << " MHz)";

View File

@ -29,22 +29,25 @@
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
add_custom_target(
cutlass_test_unit_gemm_device_sm100_bssp
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse
DEPENDS
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o
cutlass_test_unit_gemm_device_sm100_bssp_streamk
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
BATCH_SOURCES ON
BATCH_SIZE 1
@ -57,7 +60,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
BATCH_SOURCES ON
BATCH_SIZE 1
@ -70,7 +73,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
BATCH_SOURCES ON
BATCH_SIZE 1
@ -83,7 +86,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
@ -96,7 +99,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
BATCH_SOURCES ON
BATCH_SIZE 1
@ -109,7 +112,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
BATCH_SOURCES ON
BATCH_SIZE 1
@ -127,7 +130,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
BATCH_SOURCES ON
BATCH_SIZE 1
@ -140,7 +143,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
@ -148,10 +151,32 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
sm100_bssp_gemm_mxf8_mxf4_f32_f16_mxf8_q_tnt.cu
sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu
sm100_bssp_gemm_mxf8_mxf4_f32_f32_f32_q_tnt.cu
sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu
sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu
sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
@ -162,7 +187,16 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_bssp_streamk
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
BATCH_SOURCES ON
BATCH_SIZE 1

View File

@ -26,18 +26,19 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
add_subdirectory(narrow_precision)
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
add_custom_target(
cutlass_test_unit_gemm_device_sm100_sp
cutlass_test_unit_gemm_device_sm100_sparse
DEPENDS
cutlass_test_unit_gemm_device_sm100_sp_general
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
cutlass_test_unit_gemm_device_sm100_sp_streamk
cutlass_test_unit_gemm_device_sm100_sparse_general
cutlass_test_unit_gemm_device_sm100_sparse_streamk
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sp_general
cutlass_test_unit_gemm_device_sm100_sparse_general
# No batching of source to control compiler memory usage
BATCH_SOURCES ON
@ -52,23 +53,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
# No batching of source to control compiler memory usage
BATCH_SOURCES ON
BATCH_SIZE 1
sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu
sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu
sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu
sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu
sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu
sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu
)
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sp_streamk
cutlass_test_unit_gemm_device_sm100_sparse_streamk
# No batching of source to control compiler memory usage
BATCH_SOURCES ON

View File

@ -0,0 +1,77 @@
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Narrow-precision (FP4/FP6/FP8) sparse GEMM unit tests; built only when the
# build targets the SM100a (Blackwell) architecture.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
# Umbrella target that aggregates every narrow-precision sparse test executable
# declared below, one per A/B operand-type pairing.
add_custom_target(
cutlass_test_unit_gemm_device_sm100_sparse_narrow_precision
DEPENDS
cutlass_test_unit_gemm_device_sm100_sparse_f4xf4
cutlass_test_unit_gemm_device_sm100_sparse_f6xf6
cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6
cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8
cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8
)
# FP4 x FP4 inputs, with f8/f16/f32 epilogue element combinations.
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f4xf4
sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu
sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu
sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu
)
# FP6 x FP6 inputs, with f8/f16/f32 epilogue element combinations.
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f6xf6
sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu
sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu
sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu
)
# Mixed FP4/FP6 inputs (both A/B orderings).
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6
sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu
sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu
)
# Mixed FP4/FP8 inputs (both A/B orderings).
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8
sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu
sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu
)
# Mixed FP6/FP8 inputs (both A/B orderings).
cutlass_test_unit_gemm_device_add_executable_split_file(
cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8
sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu
sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu
)
endif()

View File

@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
// Test configuration 1: SM100 sparse-tensor-op GEMM.
//   A = e2m1 (FP4, row-major) x B = e3m2 (FP6, column-major), f32 accumulate,
//   f16 C/D (row-major), MMA tile 128x128x256, cluster 1x1x1,
//   1-SM TMA warp-specialized kernel/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
// TNT layouts: A row-major (T), B column-major (N), C/D row-major (T).
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// A/B alignments are hard-coded element counts
// (NOTE(review): presumably dictated by the sparse narrow-precision TMA path — confirm);
// C/D alignments are derived as elements per 128-bit access.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue assembled by the collective builder from the tile/element config above.
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
// Mainloop stage count is picked automatically with the epilogue's
// shared-memory footprint carved out (per StageCountAutoCarveoutEpi).
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// NOTE(review): ProblemShape/ElementBias/TileScheduler above are declared but
// never referenced — GemmKernel spells its own Shape and passes void for the
// scheduler slot. Likely generator boilerplate; confirm before removing.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
// Test configuration 2: same e2m1 x e3m2 / f32-accum / f16-C,D setup as
// configuration 1, but with a wider MMA tile of 128x256x256.
// Cluster 1x1x1, 1-SM TMA warp-specialized schedules.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// A/B alignments hard-coded; C/D derived as elements per 128-bit access.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
// Wider N dimension relative to the 128x128 variant.
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
// Automatic stage count with epilogue shared-memory carve-out.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// NOTE(review): ProblemShape/ElementBias/TileScheduler are unreferenced boilerplate.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
// Test configuration 3: same e2m1 x e3m2 / f32-accum / f16-C,D setup, but
// MMA tile 256x128x256 with a 2x1x1 cluster and the 2-SM TMA
// warp-specialized kernel/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// A/B alignments hard-coded; C/D derived as elements per 128-bit access.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
// 2-wide cluster pairs with the 2-SM schedules below.
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
// Automatic stage count with epilogue shared-memory carve-out.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// NOTE(review): ProblemShape/ElementBias/TileScheduler are unreferenced boilerplate.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
// Test configuration 4: same e2m1 x e3m2 / f32-accum / f16-C,D setup with the
// largest MMA tile tested here, 256x256x256, 2x1x1 cluster, 2-SM TMA
// warp-specialized kernel/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// A/B alignments hard-coded; C/D derived as elements per 128-bit access.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
// 2-wide cluster pairs with the 2-SM schedules below.
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
// Automatic stage count with epilogue shared-memory carve-out.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// NOTE(review): ProblemShape/ElementBias/TileScheduler are unreferenced boilerplate.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Functional check for configuration 1 (128x128x256 tile, 1-SM) over small
// problem sizes, comparing against the test bed's reference with relative
// tolerance, scalars resident on device, and vector scaling enabled.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
// NOTE(review): the leading (1, 1) and trailing {256, 2560} arguments are
// assumed to be alpha/beta and a problem-extent override — confirm against
// gemm_testbed_3x.hpp.
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 2.
// Functional check for configuration 2 (128x256x256 tile, 1-SM); same test-bed
// settings as the other configs (relative check, device scalars, vector scale).
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
// NOTE(review): (1, 1) / {256, 2560} assumed alpha/beta and extent override — confirm.
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 3.
// Functional check for configuration 3 (256x128x256 tile, 2-SM, 2x1x1 cluster);
// same test-bed settings as the other configs.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
// NOTE(review): (1, 1) / {256, 2560} assumed alpha/beta and extent override — confirm.
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 4.
// Functional check for configuration 4 (256x256x256 tile, 2-SM, 2x1x1 cluster);
// same test-bed settings as the other configs.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
// NOTE(review): (1, 1) / {256, 2560} assumed alpha/beta and extent override — confirm.
EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
1, 1,
test::gemm::device::CheckEquality::RELATIVE,
test::gemm::device::ScalarLoc::ON_DEVICE,
test::gemm::device::VectorScale::ENABLED,
{256, 2560}));
}
// 1.
// Void-C variant of configuration 1: identical 128x128x256 / 1-SM setup, but
// ElementC = void exercises the epilogue path without a C source tensor
// (kAlignmentC then falls back to kAlignmentD below).
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
// LayoutC is still declared even though ElementC is void.
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// A/B alignments hard-coded; D derived as elements per 128-bit access;
// C falls back to D's alignment because ElementC is void.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
// Automatic stage count with epilogue shared-memory carve-out.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// NOTE(review): ProblemShape/ElementBias/TileScheduler are unreferenced boilerplate.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
// Void-C variant of configuration 2: 128x256x256 tile, 1-SM schedules,
// ElementC = void (no C source tensor in the epilogue).
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// A/B alignments hard-coded; C falls back to D's alignment (ElementC is void).
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
// Automatic stage count with epilogue shared-memory carve-out.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// NOTE(review): ProblemShape/ElementBias/TileScheduler are unreferenced boilerplate.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
// Void-C variant of configuration 3: 256x128x256 tile, 2x1x1 cluster,
// 2-SM schedules, ElementC = void (no C source tensor in the epilogue).
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// A/B alignments hard-coded; C falls back to D's alignment (ElementC is void).
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
// 2-wide cluster pairs with the 2-SM schedules below.
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
// Automatic stage count with epilogue shared-memory carve-out.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// NOTE(review): ProblemShape/ElementBias/TileScheduler are unreferenced boilerplate.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Functional sweeps for the e2m1 x e3m2, void-C configurations. Each test runs
// the small-problem testbed with scalar arguments (1, 0) — presumably
// alpha/beta, with beta 0 since these configs have no C source; confirm
// against TestSmall's signature — relative-equality checking, device-resident
// scalars, and scale vectors enabled.
// 1. 128x128 tile, 1-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2. 128x256 tile, 1-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3. 256x128 tile, 2-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4. 256x256 tile, 2-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
// Sparse tensor-op GEMM configurations under test (TNT layouts):
//   A: float_e2m1_t, row-major, alignment 256 elements
//   B: float_e4m3_t, column-major, alignment 16 elements
//   C: half_t, D: half_t, accumulation in float.
// Four variants follow, differing only in MMA tile, cluster shape, and
// 1-SM vs 2-SM schedule.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment only when C is void
// (not the case here, but the generator emits the guard uniformly).
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue first; mainloop stage count is derived by carving the epilogue's
// shared-memory footprint out of the budget (StageCountAutoCarveoutEpi).
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2. Same template; MMA tile 128x256x256.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3. Same template; MMA tile 256x128x256, cluster 2x1x1, 2-SM schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4. Same template; MMA tile 256x256x256, cluster 2x1x1, 2-SM schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Functional sweeps for the e2m1 x e4m3, half-C configurations. Scalar
// arguments are (1, 1) — presumably alpha/beta, beta non-zero because these
// configs carry a C source tensor; confirm against TestSmall's signature.
// 1. 128x128 tile, 1-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 1,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2. 128x256 tile, 1-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 1,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3. 256x128 tile, 2-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 1,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4. 256x256 tile, 2-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 1,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 1.
// Sparse tensor-op GEMM configurations under test (TNT layouts):
//   A: float_e2m1_t, row-major, alignment 256 elements
//   B: float_e4m3_t, column-major, alignment 16 elements
//   C: void (no source tensor), D: half_t, accumulation in float.
// Four variants follow, differing only in MMA tile, cluster shape, and
// 1-SM vs 2-SM schedule.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Alignments in elements; since C is void, its alignment falls back to D's.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue first; mainloop stage count carves the epilogue's shared-memory
// footprint out of the budget (StageCountAutoCarveoutEpi).
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2. Same template; MMA tile 128x256x256.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3. Same template; MMA tile 256x128x256, cluster 2x1x1, 2-SM schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4. Same template; MMA tile 256x256x256, cluster 2x1x1, 2-SM schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e2m1_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Functional sweeps for the e2m1 x e4m3, void-C configurations. Scalar
// arguments are (1, 0) — presumably alpha/beta, beta zero because these
// configs have no C source tensor; confirm against TestSmall's signature.
// 1. 128x128 tile, 1-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2. 128x256 tile, 1-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3. 256x128 tile, 2-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4. 256x256 tile, 2-SM schedule.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace cfg = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
  namespace dut = test::gemm::device;
  bool const passed = dut::TestSmall<cfg::Gemm>(
      1, 0,
      dut::CheckEquality::RELATIVE,
      dut::ScalarLoc::ON_DEVICE,
      dut::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
// Sparse tensor-op GEMM configuration under test (TNT layouts):
//   A: float_e3m2_t, row-major, alignment 256 elements
//   B: float_e2m1_t, column-major, alignment 128 elements
//   C: half_t, D: half_t, accumulation in float
//   MMA tile 128x128x256, cluster 1x1x1, 1-SM sparse TMA warp-specialized schedules.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Alignments in elements; the void-guard on C's alignment is emitted by the
// generator uniformly even though C is half_t here.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue first; mainloop stage count carves the epilogue's shared-memory
// footprint out of the budget (StageCountAutoCarveoutEpi).
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OpClassEpilogue,
MmaTileShape,
ClusterShape,
EpilogueTile,
ElementAccumulator,
ElementEpilogueCompute,
ElementC, LayoutC, kAlignmentC,
ElementD, LayoutD, kAlignmentD,
EpilogueScheduleType
>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OpClassMainLoop,
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
StageCount,
KernelScheduleType
>::CollectiveOp;
// Assemble the kernel and wrap it in the host-side device adapter used by the tests.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e2m1 (column-major), C = f16, D = f16 (both
// row-major), f32 accumulation; 128x256x256 MMA tile, 1x1x1 cluster,
// TMA warp-specialized 1-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// When C is void (no source tensor) its alignment falls back to D's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e2m1 (column-major), C = f16, D = f16 (both
// row-major), f32 accumulation; 256x128x256 MMA tile, 2x1x1 cluster,
// TMA warp-specialized 2-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// When C is void (no source tensor) its alignment falls back to D's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e2m1 (column-major), C = f16, D = f16 (both
// row-major), f32 accumulation; 256x256x256 MMA tile, 2x1x1 cluster,
// TMA warp-specialized 2-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// When C is void (no source tensor) its alignment falls back to D's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Functional sweep of small problem sizes for the 128x128x256 / 1-SM config.
// NOTE(review): the leading (1, 1) arguments appear to be the epilogue
// scalars and {256, 2560} the sweep extents -- confirm against TestSmall's
// declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2.
// Functional sweep of small problem sizes for the 128x256x256 / 1-SM config.
// NOTE(review): the leading (1, 1) arguments appear to be the epilogue
// scalars and {256, 2560} the sweep extents -- confirm against TestSmall's
// declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3.
// Functional sweep of small problem sizes for the 256x128x256 / 2-SM config.
// NOTE(review): the leading (1, 1) arguments appear to be the epilogue
// scalars and {256, 2560} the sweep extents -- confirm against TestSmall's
// declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4.
// Functional sweep of small problem sizes for the 256x256x256 / 2-SM config.
// NOTE(review): the leading (1, 1) arguments appear to be the epilogue
// scalars and {256, 2560} the sweep extents -- confirm against TestSmall's
// declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 1.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e2m1 (column-major), no C source (void),
// D = f16 (row-major), f32 accumulation; 128x128x256 MMA tile, 1x1x1
// cluster, TMA warp-specialized 1-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is void here, so its alignment falls back to D's via the ternary.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e2m1 (column-major), no C source (void),
// D = f16 (row-major), f32 accumulation; 128x256x256 MMA tile, 1x1x1
// cluster, TMA warp-specialized 1-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is void here, so its alignment falls back to D's via the ternary.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e2m1 (column-major), no C source (void),
// D = f16 (row-major), f32 accumulation; 256x128x256 MMA tile, 2x1x1
// cluster, TMA warp-specialized 2-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is void here, so its alignment falls back to D's via the ternary.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e2m1 (column-major), no C source (void),
// D = f16 (row-major), f32 accumulation; 256x256x256 MMA tile, 2x1x1
// cluster, TMA warp-specialized 2-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is void here, so its alignment falls back to D's via the ternary.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Functional sweep for the void-C 128x128x256 / 1-SM config.
// NOTE(review): the second argument is 0 here (vs. 1 in the f16-C tests),
// consistent with a disabled C source -- presumably beta; confirm against
// TestSmall's declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2.
// Functional sweep for the void-C 128x256x256 / 1-SM config.
// NOTE(review): the second argument is 0 here (vs. 1 in the f16-C tests),
// consistent with a disabled C source -- presumably beta; confirm against
// TestSmall's declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3.
// Functional sweep for the void-C 256x128x256 / 2-SM config.
// NOTE(review): the second argument is 0 here (vs. 1 in the f16-C tests),
// consistent with a disabled C source -- presumably beta; confirm against
// TestSmall's declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4.
// Functional sweep for the void-C 256x256x256 / 2-SM config.
// NOTE(review): the second argument is 0 here (vs. 1 in the f16-C tests),
// consistent with a disabled C source -- presumably beta; confirm against
// TestSmall's declaration in gemm_testbed_3x.hpp.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm::Gemm;
  bool const passed = test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@ -40,8 +40,8 @@
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../common/cutlass_unit_test.h"
#include "../gemm_testbed_3x.hpp"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;

View File

@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e4m3 (column-major), C = f16, D = f16 (both
// row-major), f32 accumulation; 128x128x256 MMA tile, 1x1x1 cluster,
// TMA warp-specialized 1-SM mainloop and epilogue schedules.
// NOTE(review): kAlignmentB is 16 here (vs. 128 in the e2m1-B files),
// matching the wider e4m3 B operand -- confirm units against the builder.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// When C is void (no source tensor) its alignment falls back to D's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e4m3 (column-major), C = f16, D = f16 (both
// row-major), f32 accumulation; 128x256x256 MMA tile, 1x1x1 cluster,
// TMA warp-specialized 1-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// When C is void (no source tensor) its alignment falls back to D's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
// Kernel configuration under test: SM100 sparse tensor-op GEMM with
// A = e3m2 (row-major), B = e4m3 (column-major), C = f16, D = f16 (both
// row-major), f32 accumulation; 256x128x256 MMA tile, 2x1x1 cluster,
// TMA warp-specialized 2-SM mainloop and epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// When C is void (no source tensor) its alignment falls back to D's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Note: ProblemShape, ElementBias, and TileScheduler are not referenced below.
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Mainloop stage count is derived automatically after carving out the
// epilogue's shared-memory footprint.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  cute::Shape<int,int,int,int>,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e3m2 row-major, B: e4m3 column-major, C/D: f16 row-major ("tnt"),
//   f32 accumulation, MMA tile 256x256x256, cluster 2x1x1, 2-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e3m2 row-major, B: e4m3 column-major, C: void (no source), D: f16
//   row-major, f32 accumulation, MMA tile 128x128x256, cluster 1x1x1,
//   1-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e3m2 row-major, B: e4m3 column-major, C: void (no source), D: f16
//   row-major, f32 accumulation, MMA tile 128x256x256, cluster 1x1x1,
//   1-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e3m2 row-major, B: e4m3 column-major, C: void (no source), D: f16
//   row-major, f32 accumulation, MMA tile 256x128x256, cluster 2x1x1,
//   2-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e3m2 row-major, B: e4m3 column-major, C: void (no source), D: f16
//   row-major, f32 accumulation, MMA tile 256x256x256, cluster 2x1x1,
//   2-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e3m2_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 256;
constexpr int kAlignmentB = 16;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 0) arguments are presumably alpha/beta --
  // beta = 0 is consistent with the void C operand; confirm against
  // gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 0) arguments are presumably alpha/beta --
  // beta = 0 is consistent with the void C operand; confirm against
  // gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 0) arguments are presumably alpha/beta --
  // beta = 0 is consistent with the void C operand; confirm against
  // gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 0) arguments are presumably alpha/beta --
  // beta = 0 is consistent with the void C operand; confirm against
  // gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e4m3 row-major, B: e2m1 column-major, C/D: f16 row-major ("tnt"),
//   f32 accumulation, MMA tile 128x128x256, cluster 1x1x1, 1-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e4m3 row-major, B: e2m1 column-major, C/D: f16 row-major ("tnt"),
//   f32 accumulation, MMA tile 128x256x256, cluster 1x1x1, 1-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 3.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e4m3 row-major, B: e2m1 column-major, C/D: f16 row-major ("tnt"),
//   f32 accumulation, MMA tile 256x128x256, cluster 2x1x1, 2-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 4.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e4m3 row-major, B: e2m1 column-major, C/D: f16 row-major ("tnt"),
//   f32 accumulation, MMA tile 256x256x256, cluster 2x1x1, 2-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 1.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 2.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 3.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 4.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
  // Functional sweep over small problem sizes ({256, 2560} bounds) with
  // device-resident scalars and a relative-equality reference check.
  // NOTE(review): leading (1, 1) arguments are presumably alpha/beta --
  // confirm against gemm_testbed_3x.hpp.
  auto const passed = test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560});
  EXPECT_TRUE(passed);
}
// 1.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e4m3 row-major, B: e2m1 column-major, C: void (no source), D: f16
//   row-major, f32 accumulation, MMA tile 128x128x256, cluster 1x1x1,
//   1-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 2.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
// SM100 sparse tensor-op GEMM configuration:
//   A: e4m3 row-major, B: e2m1 column-major, C: void (no source), D: f16
//   row-major, f32 accumulation, MMA tile 128x256x256, cluster 1x1x1,
//   1-SM TMA schedules.
// Operand layouts.
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Operand element types.
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Accumulator / epilogue compute types.
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
using ElementBias = cutlass::half_t;
// Alignments in elements; C falls back to D's alignment when C is void.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
// Shapes and schedule selection.
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using TileScheduler = cutlass::gemm::PersistentScheduler;
// Epilogue is built first so the mainloop stage count can be auto-derived
// with the epilogue's resource usage carved out.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClassEpilogue,
    MmaTileShape, ClusterShape, EpilogueTile,
    ElementAccumulator, ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType>::CollectiveOp;
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    StageCount, KernelScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
} // namespace
// 3.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e2m1 (col-major),
// f32 accumulation, no C source (ElementC = void), f16 D, 256x128x256 MMA tile,
// 2x1x1 cluster, 2-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// With ElementC = void there is no source tensor; D's alignment stands in for C's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e2m1 (col-major),
// f32 accumulation, no C source (ElementC = void), f16 D, 256x256x256 MMA tile,
// 2x1x1 cluster, 2-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e2m1_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// With ElementC = void there is no source tensor; D's alignment stands in for C's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Small-problem functional sweep for the 128x128 e4m3 x e2m1 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 2.
// Small-problem functional sweep for the 128x256 e4m3 x e2m1 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 3.
// Small-problem functional sweep for the 256x128 e4m3 x e2m1 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 4.
// Small-problem functional sweep for the 256x256 e4m3 x e2m1 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

View File

@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "../../../../common/cutlass_unit_test.h"
#include "../../gemm_testbed_3x.hpp"
using namespace cute;
// * Test list
// 1. 128x128_tnt
// 2. 128x256_tnt
// 3. 256x128_tnt
// 4. 256x256_tnt
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// 1.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, f16 C source and f16 D, 128x128x256 MMA tile,
// 1x1x1 cluster, 1-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is half_t here, so the void branch is dead and C's alignment derives from its own width.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, f16 C source and f16 D, 128x256x256 MMA tile,
// 1x1x1 cluster, 1-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is half_t here, so the void branch is dead and C's alignment derives from its own width.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, f16 C source and f16 D, 256x128x256 MMA tile,
// 2x1x1 cluster, 2-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is half_t here, so the void branch is dead and C's alignment derives from its own width.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, f16 C source and f16 D, 256x256x256 MMA tile,
// 2x1x1 cluster, 2-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// C is half_t here, so the void branch is dead and C's alignment derives from its own width.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Small-problem functional sweep for the 128x128 e4m3 x e3m2 sparse GEMM (f16 C source).
// NOTE(review): the leading 1, 1 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 2.
// Small-problem functional sweep for the 128x256 e4m3 x e3m2 sparse GEMM (f16 C source).
// NOTE(review): the leading 1, 1 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 3.
// Small-problem functional sweep for the 256x128 e4m3 x e3m2 sparse GEMM (f16 C source).
// NOTE(review): the leading 1, 1 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 4.
// Small-problem functional sweep for the 256x256 e4m3 x e3m2 sparse GEMM (f16 C source).
// NOTE(review): the leading 1, 1 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 1.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, no C source (ElementC = void), f16 D, 128x128x256 MMA tile,
// 1x1x1 cluster, 1-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// With ElementC = void there is no source tensor; D's alignment stands in for C's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 2.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, no C source (ElementC = void), f16 D, 128x256x256 MMA tile,
// 1x1x1 cluster, 1-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// With ElementC = void there is no source tensor; D's alignment stands in for C's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using MmaTileShape = Shape<_128, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 3.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, no C source (ElementC = void), f16 D, 256x128x256 MMA tile,
// 2x1x1 cluster, 2-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// With ElementC = void there is no source tensor; D's alignment stands in for C's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _128, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 4.
// Blockscaled sparse SM100 GEMM config: A = e4m3 (row-major) x B = e3m2 (col-major),
// f32 accumulation, no C source (ElementC = void), f16 D, 256x256x256 MMA tile,
// 2x1x1 cluster, 2-SM TMA warp-specialized mainloop/epilogue schedules.
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e3m2_t;
using ElementC = void;
using ElementD = cutlass::half_t;
// Alignments are element counts. NOTE(review): A=32 / B=128 presumably match what
// the sparse mainloop builder requires for these operand widths -- confirm.
constexpr int kAlignmentA = 32;
constexpr int kAlignmentB = 128;
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// With ElementC = void there is no source tensor; D's alignment stands in for C's.
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
using ProblemShape = Shape<int,int,int,int>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using MmaTileShape = Shape<_256, _256, _256>;
using ArchTag = cutlass::arch::Sm100;
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
using ElementAccumulator = float;
using ElementEpilogueCompute = float;
// ElementBias / TileScheduler are not referenced below; kept for parity with the
// generator's other configurations.
using ElementBias = cutlass::half_t;
using TileScheduler = cutlass::gemm::PersistentScheduler;
using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag,
    OpClassEpilogue,
    MmaTileShape,
    ClusterShape,
    EpilogueTile,
    ElementAccumulator,
    ElementEpilogueCompute,
    ElementC, LayoutC, kAlignmentC,
    ElementD, LayoutD, kAlignmentD,
    EpilogueScheduleType
  >::CollectiveOp;
// Size the mainloop stage count around the epilogue's shared-memory carveout.
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag,
    OpClassMainLoop,
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementAccumulator,
    MmaTileShape,
    ClusterShape,
    StageCount,
    KernelScheduleType
  >::CollectiveOp;
// Consistency: use the ProblemShape alias declared above instead of restating
// cute::Shape<int,int,int,int> (same type, one source of truth).
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue,
  void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}
// 1.
// Small-problem functional sweep for the 128x128 e4m3 x e3m2 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 2.
// Small-problem functional sweep for the 128x256 e4m3 x e3m2 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 3.
// Small-problem functional sweep for the 256x128 e4m3 x e3m2 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
// 4.
// Small-problem functional sweep for the 256x256 e4m3 x e3m2 sparse GEMM (C = void).
// NOTE(review): the leading 1, 0 are presumably alpha/beta -- confirm against TestSmall's signature.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  using GemmUnderTest =
      cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm::Gemm;
  EXPECT_TRUE(test::gemm::device::TestSmall<GemmUnderTest>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)