@ -28,7 +28,6 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
@ -59,7 +58,10 @@ std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &deviceProperti
|
||||
|
||||
int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor;
|
||||
if (deviceMajorMinor) {
|
||||
int32_t clock_MHz = deviceProperties.clockRate / 1000;
|
||||
int32_t clock_MHz;
|
||||
int32_t clock_KHz;
|
||||
cudaDeviceGetAttribute(&clock_KHz, cudaDevAttrClockRate, 0);
|
||||
clock_MHz = clock_KHz / 1000;
|
||||
out << "GPU(compute_"
|
||||
<< deviceMajorMinor << ", "
|
||||
<< deviceProperties.multiProcessorCount << " SMs @ " << clock_MHz << " MHz)";
|
||||
|
||||
@ -29,22 +29,25 @@
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f32_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f32_f32_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -57,7 +60,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_f16_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_f16_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -70,7 +73,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_nvf4_nvf4_f32_f16_nvf4_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_nvf4_nvf4_f32_f16_nvf4_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -83,7 +86,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f32_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f32_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -96,7 +99,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_f16_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_f16_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -109,7 +112,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf8_f32_f16_mxf8_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8_mxf8_f32_f16_mxf8_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -127,7 +130,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_o
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_o
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -140,7 +143,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf8_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf4_mxf4mxf8_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -148,10 +151,32 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
sm100_bssp_gemm_mxf8_mxf4_f32_f16_mxf8_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf8_mxf4_f32_f16_f16_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf8_mxf4_f32_f32_f32_q_tnt.cu
|
||||
|
||||
sm100_bssp_gemm_mxf4_mxf8_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_mxf4_mxf4_f32_q
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf8mxf6_mxf6mxf8_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_bssp_gemm_mxf6_mxf8_f32_f16_f16_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf8_mxf6_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4mxf6_mxf4mxf6_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_bssp_gemm_mxf4_mxf6_f32_f16_f16_q_tnt.cu
|
||||
sm100_bssp_gemm_mxf6_mxf4_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf4_mxf4_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
@ -162,7 +187,16 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_bssp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_mxf6_mxf6_f32_q
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_bssp_gemm_mxf6_mxf6_f32_f16_f16_q_tnt.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_blockscaled_sparse_streamk
|
||||
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -26,18 +26,19 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
add_subdirectory(narrow_precision)
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
|
||||
add_custom_target(
|
||||
cutlass_test_unit_gemm_device_sm100_sp
|
||||
cutlass_test_unit_gemm_device_sm100_sparse
|
||||
DEPENDS
|
||||
cutlass_test_unit_gemm_device_sm100_sp_general
|
||||
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
|
||||
cutlass_test_unit_gemm_device_sm100_sp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_general
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_streamk
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sp_general
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_general
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
@ -52,23 +53,7 @@ cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sp_qmma_variance
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
BATCH_SIZE 1
|
||||
|
||||
sm100_sp_gemm_f4_f4_f32_f16_f8_qmma.cu
|
||||
sm100_sp_gemm_f4_f4_f32_f16_f16_qmma.cu
|
||||
sm100_sp_gemm_f4_f4_f32_f32_f32_qmma.cu
|
||||
|
||||
sm100_sp_gemm_f6_f6_f32_f16_f8_qmma.cu
|
||||
sm100_sp_gemm_f6_f6_f32_f16_f16_qmma.cu
|
||||
sm100_sp_gemm_f6_f6_f32_f32_f32_qmma.cu
|
||||
)
|
||||
|
||||
cutlass_test_unit_gemm_device_add_executable_split_file(
|
||||
cutlass_test_unit_gemm_device_sm100_sp_streamk
|
||||
cutlass_test_unit_gemm_device_sm100_sparse_streamk
|
||||
|
||||
# No batching of source to control compiler memory usage
|
||||
BATCH_SOURCES ON
|
||||
|
||||
@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)

# Aggregate target: building it builds every narrow-precision sparse GEMM test
# executable declared below. Only configured when the arch list matches 100a.
add_custom_target(
  cutlass_test_unit_gemm_device_sm100_sparse_narrow_precision
  DEPENDS
  cutlass_test_unit_gemm_device_sm100_sparse_f4xf4
  cutlass_test_unit_gemm_device_sm100_sparse_f6xf6
  cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6
  cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8
  cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8
)

# f4 x f4 operands.
cutlass_test_unit_gemm_device_add_executable_split_file(
  cutlass_test_unit_gemm_device_sm100_sparse_f4xf4

  sm100_sp_gemm_f4_f4_f32_f16_f8_tn.cu
  sm100_sp_gemm_f4_f4_f32_f16_f16_tn.cu
  sm100_sp_gemm_f4_f4_f32_f32_f32_tn.cu
)

# f6 x f6 operands.
cutlass_test_unit_gemm_device_add_executable_split_file(
  cutlass_test_unit_gemm_device_sm100_sparse_f6xf6

  sm100_sp_gemm_f6_f6_f32_f16_f8_tn.cu
  sm100_sp_gemm_f6_f6_f32_f16_f16_tn.cu
  sm100_sp_gemm_f6_f6_f32_f32_f32_tn.cu
)

# Mixed f4 / f6 operands (both orderings).
cutlass_test_unit_gemm_device_add_executable_split_file(
  cutlass_test_unit_gemm_device_sm100_sparse_f4f6xf4f6

  sm100_sp_gemm_f4_f6_f32_f16_f16_tn.cu
  sm100_sp_gemm_f6_f4_f32_f16_f16_tn.cu
)

# Mixed f4 / f8 operands (both orderings).
cutlass_test_unit_gemm_device_add_executable_split_file(
  cutlass_test_unit_gemm_device_sm100_sparse_f4f8xf4f8

  sm100_sp_gemm_f4_f8_f32_f16_f16_tn.cu
  sm100_sp_gemm_f8_f4_f32_f16_f16_tn.cu
)

# Mixed f6 / f8 operands (both orderings).
cutlass_test_unit_gemm_device_add_executable_split_file(
  cutlass_test_unit_gemm_device_sm100_sparse_f6f8xf6f8

  sm100_sp_gemm_f6_f8_f32_f16_f16_tn.cu
  sm100_sp_gemm_f8_f6_f32_f16_f16_tn.cu
)

endif()
|
||||
@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  // Functional check of the 128x128x256, 1-SM, half_t-C kernel configuration.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing list presumably overrides the K extents exercised by TestSmall -
  // confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  // Functional check of the 128x256x256, 1-SM, half_t-C kernel configuration.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing list presumably overrides the K extents exercised by TestSmall -
  // confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  // Functional check of the 256x128x256, 2-SM, half_t-C kernel configuration.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing list presumably overrides the K extents exercised by TestSmall -
  // confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  // Functional check of the 256x256x256, 2-SM, half_t-C kernel configuration.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing list presumably overrides the K extents exercised by TestSmall -
  // confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e3m2_f32_void_f16_128x128x256_..._1sm namespace.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta (the second
  // is 0 here, alongside a void ElementC), and the trailing brace list presumably a
  // set of K extents -- confirm against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e3m2_f32_void_f16_128x256x256_..._1sm namespace.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta (the second
  // is 0 here, alongside a void ElementC), and the trailing brace list presumably a
  // set of K extents -- confirm against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e3m2_f32_void_f16_256x128x256_..._2sm namespace.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta (the second
  // is 0 here, alongside a void ElementC), and the trailing brace list presumably a
  // set of K extents -- confirm against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e3m2_f32_void_f16_256x256x256_..._2sm namespace.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta (the second
  // is 0 here, alongside a void ElementC), and the trailing brace list presumably a
  // set of K extents -- confirm against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e4m3_f32_f16_f16_128x128x256_..._1sm namespace.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing brace list presumably a set of K extents -- confirm against the
  // TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e4m3_f32_f16_f16_128x256x256_..._1sm namespace.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing brace list presumably a set of K extents -- confirm against the
  // TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e4m3_f32_f16_f16_256x128x256_..._2sm namespace.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing brace list presumably a set of K extents -- confirm against the
  // TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
// Functional check of the Gemm type declared in the matching
// ..._e2m1_e4m3_f32_f16_f16_256x256x256_..._2sm namespace.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;

  // NOTE(review): the two leading scalars are presumably alpha and beta, and the
  // trailing brace list presumably a set of K extents -- confirm against the
  // TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e2m1_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e2m1_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 0` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e2m1_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 0` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e2m1_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 0` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e2m1_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 0` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 1` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 1` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 1` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the namespace holding this case's kernel type definitions.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;

  // Run the testbed's TestSmall on gemm::Gemm and require a truthy result.
  // NOTE(review): the leading `1, 1` are presumably the alpha/beta scalars and
  // `{256, 2560}` presumably a list of problem extents to exercise -- confirm
  // against the TestSmall declaration in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading scalars (1, 0) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading scalars (1, 0) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading scalars (1, 0) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading scalars (1, 0) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -40,8 +40,8 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../common/cutlass_unit_test.h"
|
||||
#include "../gemm_testbed_3x.hpp"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading scalars (1, 1) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading scalars (1, 1) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading scalars (1, 1) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
// Functional test: feeds the result of the TestSmall testbed, instantiated on
// this configuration's Gemm type, to EXPECT_TRUE.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  // Local shorthand for the namespace that defines this configuration's Gemm.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading scalars (1, 1) are presumably alpha / beta and
  // the trailing {256, 2560} presumably overrides problem extents -- confirm
  // against TestSmall's signature in gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e3m2_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 256;
|
||||
constexpr int kAlignmentB = 16;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e3m2_e4m3_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading "1, 0" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e3m2_e4m3_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading "1, 0" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e3m2_e4m3_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading "1, 0" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e3m2_e4m3_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading "1, 0" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 0,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading "1, 1" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 2.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;

  // NOTE(review): the leading "1, 1" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 3.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading "1, 1" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
// 4.
|
||||
// Functional check: runs TestSmall on the Gemm type of the identically named namespace
// and expects it to report success.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  // Short alias for the kernel-configuration namespace under test.
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;

  // NOTE(review): the leading "1, 1" are presumably alpha and beta, and {256, 2560} presumably
  // the K extents to exercise -- confirm against gemm_testbed_3x.hpp.
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
      1, 1,
      test::gemm::device::CheckEquality::RELATIVE,
      test::gemm::device::ScalarLoc::ON_DEVICE,
      test::gemm::device::VectorScale::ENABLED,
      {256, 2560}));
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e2m1_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional test for case 1 (e4m3 x e2m1 sparse GEMM; the namespace name indicates a void C
// operand, f16 output, 128x128x256 tile and the 1-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e2m1_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 2.
|
||||
// Functional test for case 2 (e4m3 x e2m1 sparse GEMM; the namespace name indicates a void C
// operand, f16 output, 128x256x256 tile and the 1-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e2m1_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 3.
|
||||
// Functional test for case 3 (e4m3 x e2m1 sparse GEMM; the namespace name indicates a void C
// operand, f16 output, 256x128x256 tile and the 2-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e2m1_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 4.
|
||||
// Functional test for case 4 (e4m3 x e2m1 sparse GEMM; the namespace name indicates a void C
// operand, f16 output, 256x256x256 tile and the 2-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e2m1_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "../../../../common/cutlass_unit_test.h"
|
||||
#include "../../gemm_testbed_3x.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// * Test list
|
||||
// 1. 128x128_tnt
|
||||
// 2. 128x256_tnt
|
||||
// 3. 256x128_tnt
|
||||
// 4. 256x256_tnt
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional test for case 1 (e4m3 x e3m2, half_t C and D, 128x128x256 tile, 1-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_f16_f16_128x128x256_0_tnt_align32_q_1sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,                                          // presumably alpha = 1, beta = 1 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 2.
|
||||
// Functional test for case 2 (e4m3 x e3m2, half_t C and D, 128x256x256 tile, 1-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_f16_f16_128x256x256_0_tnt_align32_q_1sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,                                          // presumably alpha = 1, beta = 1 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 3.
|
||||
// Functional test for case 3 (e4m3 x e3m2, half_t C and D, 256x128x256 tile, 2-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_f16_f16_256x128x256_0_tnt_align32_q_2sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,                                          // presumably alpha = 1, beta = 1 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 4.
|
||||
// Functional test for case 4 (e4m3 x e3m2, half_t C and D, 256x256x256 tile, 2-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_f16_f16_256x256x256_0_tnt_align32_q_2sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 1,                                          // presumably alpha = 1, beta = 1 -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
|
||||
// 1.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 2.
|
||||
namespace cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 3.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _128, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 4.
|
||||
namespace cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e3m2_t;
|
||||
using ElementC = void;
|
||||
using ElementD = cutlass::half_t;
|
||||
|
||||
constexpr int kAlignmentA = 32;
|
||||
constexpr int kAlignmentB = 128;
|
||||
constexpr int kAlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
constexpr int kAlignmentC = cute::is_same_v<ElementC, void> ? kAlignmentD : 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OpClassEpilogue = cutlass::arch::OpClassSparseTensorOp;
|
||||
using OpClassMainLoop = cutlass::arch::OpClassSparseTensorOp;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using KernelScheduleType = cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogueCompute = float;
|
||||
using ElementBias = cutlass::half_t;
|
||||
using TileScheduler = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassEpilogue,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
ElementAccumulator,
|
||||
ElementEpilogueCompute,
|
||||
ElementC, LayoutC, kAlignmentC,
|
||||
ElementD, LayoutD, kAlignmentD,
|
||||
EpilogueScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCount = cutlass::gemm::collective::StageCountAutoCarveoutEpi<CollectiveEpilogue>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OpClassMainLoop,
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
StageCount,
|
||||
KernelScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
}
|
||||
|
||||
// 1.
|
||||
// Functional test for case 1, void-C variant (e4m3 x e3m2, half_t D, 128x128x256 tile, 1-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x128x64spgemm_e4m3_e3m2_f32_void_f16_128x128x256_0_tnt_align32_q_1sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 (ElementC is void) -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 2.
|
||||
// Functional test for case 2, void-C variant (e4m3 x e3m2, half_t D, 128x256x256 tile, 1-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s128x256x64spgemm_e4m3_e3m2_f32_void_f16_128x256x256_0_tnt_align32_q_1sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 (ElementC is void) -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 3.
|
||||
// Functional test for case 3, void-C variant (e4m3 x e3m2, half_t D, 256x128x256 tile, 2-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x128x64spgemm_e4m3_e3m2_f32_void_f16_256x128x256_0_tnt_align32_q_2sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 (ElementC is void) -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
// 4.
|
||||
// Functional test for case 4, void-C variant (e4m3 x e3m2, half_t D, 256x256x256 tile, 2-SM schedule).
// Passes when TestSmall reports success for the Gemm type declared in the aliased namespace.
TEST(cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm, functional) {
  namespace gemm = cutlass3x_sm100_sptensorop_s256x256x64spgemm_e4m3_e3m2_f32_void_f16_256x256x256_0_tnt_align32_q_2sm;
  EXPECT_TRUE(test::gemm::device::TestSmall<gemm::Gemm>(
    1, 0,                                          // presumably alpha = 1, beta = 0 (ElementC is void) -- confirm against TestSmall's signature
    test::gemm::device::CheckEquality::RELATIVE,
    test::gemm::device::ScalarLoc::ON_DEVICE,
    test::gemm::device::VectorScale::ENABLED,
    {256, 2560}));                                 // presumably the K extents to exercise -- confirm
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
Reference in New Issue
Block a user