CUTLASS 3.6.0 (#1850)

* v3.6

* update changelog

* update readme

* fix typo

* fixing typos

* hopper gemm with weight prefetch

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Authored by Yujia Zhai on 2024-10-09 12:33:27 -07:00, committed by GitHub
parent 0837a2a00a, commit cc3c29a81a
354 changed files with 105943 additions and 8203 deletions


@@ -0,0 +1,58 @@
# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Compress Kernel
#
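# Aggregate target depending on the per-datatype compressor unit tests (f32 / f16 / f8) defined below.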
add_custom_target(
cutlass_test_unit_sm90_structured_sparse_gemm_compressor
DEPENDS
cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f32
cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f16
cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f8
)
cutlass_test_unit_add_executable(
cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f32
sm90_sparse_gemm_compressor_f32.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f16
sm90_sparse_gemm_compressor_f16.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f8
sm90_sparse_gemm_compressor_f8.cu
)


@@ -0,0 +1,95 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major
#include "cutlass/arch/config.h" // CUTLASS_ARCH_MMA_SM90_SUPPORTED
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
#include "cutlass/gemm/collective/builders/sm90_common.inl" // gmma_ss_tag_to_major_A
#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" // Sm90GemmSparseConfig
#include "testbed_sparse_gemm_compressor.hpp" // TestbedSparseGemmCompressor
///////////////////////////////////////////////////////////////////////////////////////////////////
// * Test Plan
// ElementA : fp16 (half_t) / bf16 (bfloat16_t)
// LayoutA : row / col
// Gemm : 1x, 2x, 3x multiples of the alignment requirement, plus corner cases smaller than the alignment requirement
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f16_t)
{
// Test Settings
using ElementA = cutlass::half_t;
using LayoutATag = cutlass::layout::RowMajor;
// Deduced From Test Settings
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
using ElementAMma = cute::sparse_elem<2, ElementA>;
using ElementEMma = cute::sparse_elem<8, uint8_t>;
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<32>>;
using CompressorKernel = cutlass::transform::kernel::
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
// Test Bed
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
EXPECT_TRUE(testbed.run_auto());
}
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f16_n)
{
// Test Settings
using ElementA = cutlass::bfloat16_t;
using LayoutATag = cutlass::layout::ColumnMajor;
// Deduced From Test Settings
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
using ElementAMma = cute::sparse_elem<2, ElementA>;
using ElementEMma = cute::sparse_elem<8, uint8_t>;
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<64>>;
using CompressorKernel = cutlass::transform::kernel::
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
// Test Bed
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
EXPECT_TRUE(testbed.run_auto());
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)


@@ -0,0 +1,95 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major
#include "cutlass/arch/config.h" // CUTLASS_ARCH_MMA_SM90_SUPPORTED
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
#include "cutlass/gemm/collective/builders/sm90_common.inl" // gmma_ss_tag_to_major_A
#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" // Sm90GemmSparseConfig
#include "testbed_sparse_gemm_compressor.hpp" // TestbedSparseGemmCompressor
///////////////////////////////////////////////////////////////////////////////////////////////////
// * Test Plan
// ElementA : fp32 (float) / tf32 (tfloat32_t)
// LayoutA : row / col
// Gemm : 1x, 2x, 3x multiples of the alignment requirement, plus corner cases smaller than the alignment requirement
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f32_t)
{
// Test Settings
using ElementA = float;
using LayoutATag = cutlass::layout::RowMajor;
// Deduced From Test Settings
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
using ElementAMma = cute::sparse_elem<2, ElementA>;
using ElementEMma = cute::sparse_elem<4, uint8_t>;
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<16>>;
using CompressorKernel = cutlass::transform::kernel::
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
// Test Bed
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
EXPECT_TRUE(testbed.run_auto());
}
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f32_n)
{
// Test Settings
using ElementA = cutlass::tfloat32_t;
using LayoutATag = cutlass::layout::ColumnMajor;
// Deduced From Test Settings
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
using ElementAMma = cute::sparse_elem<2, ElementA>;
using ElementEMma = cute::sparse_elem<4, uint8_t>;
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<32>>;
using CompressorKernel = cutlass::transform::kernel::
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
// Test Bed
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
EXPECT_TRUE(testbed.run_auto());
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)


@@ -0,0 +1,95 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major
#include "cutlass/arch/config.h" // CUTLASS_ARCH_MMA_SM90_SUPPORTED
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
#include "cutlass/gemm/collective/builders/sm90_common.inl" // gmma_ss_tag_to_major_A
#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" // Sm90GemmSparseConfig
#include "testbed_sparse_gemm_compressor.hpp" // TestbedSparseGemmCompressor
///////////////////////////////////////////////////////////////////////////////////////////////////
// * Test Plan
// ElementA : fp8 (e4m3 / e5m2)
// LayoutA : row / col
// Gemm : 1x, 2x, 3x multiples of the alignment requirement, plus corner cases smaller than the alignment requirement
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f8_t)
{
// Test Settings
using ElementA = cutlass::float_e4m3_t;
using LayoutATag = cutlass::layout::RowMajor;
// Deduced From Test Settings
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
using ElementAMma = cute::sparse_elem<2, ElementA>;
using ElementEMma = cute::sparse_elem<8, uint8_t>;
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<64>>;
using CompressorKernel = cutlass::transform::kernel::
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
// Test Bed
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
EXPECT_TRUE(testbed.run_auto());
}
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f8_n)
{
// Test Settings
using ElementA = cutlass::float_e5m2_t;
using LayoutATag = cutlass::layout::ColumnMajor;
// Deduced From Test Settings
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
using ElementAMma = cute::sparse_elem<2, ElementA>;
using ElementEMma = cute::sparse_elem<8, uint8_t>;
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<64>>;
using CompressorKernel = cutlass::transform::kernel::
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
// Test Bed
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
EXPECT_TRUE(testbed.run_auto());
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)


@@ -0,0 +1,480 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Compression utilities specific to SM90 structured sparse kernels
*/
#pragma once
#include <algorithm> // std::fill
#include <array> // std::array
#include <cstdio>
#include <random> // std::mt19937
#include "cute/container/bit_field.hpp" // cute::bit_field
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor
#include "cutlass/arch/arch.h" // cutlass::arch::Sm90
#include "cutlass/cutlass.h" // cutlass::Status
#include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo
#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter
namespace cutlass
{
namespace transform
{
namespace kernel
{
using namespace cute;
namespace detail {
template<typename T>
CUTLASS_HOST_DEVICE
static uint8_t
encode_in_chunk_idx_legacy(int in_chunk_idx){
if (sizeof(T) == 4) {
return in_chunk_idx == 0 ? 0b0100 : 0b1110;
}
else {
uint8_t res = 0;
if (in_chunk_idx == 0) {
res = 0b00;
}
else if (in_chunk_idx == 1) {
res = 0b01;
}
else if (in_chunk_idx == 2) {
res = 0b10;
}
else {
res = 0b11;
}
return res;
}
}
template <
class SparseConfig,
class EngineA,
class LayoutA,
class EngineAc,
class LayoutAc
>
CUTLASS_HOST_DEVICE
static void
compress_two_chunks_legacy(
Tensor<EngineA, LayoutA> tensorA,
Tensor<EngineAc, LayoutAc> tensorAc,
uint8_t& meta_two_chunk,
int effective_elems) {
using ElementA = typename EngineAc::value_type;
static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
static constexpr int ElementEBitsPerElementAMma = typename SparseConfig::ElementEBitsPerElementAMma{};
static constexpr int LogicalSubChunk = ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
static constexpr int PhysicalSubChunk = ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
/*
Legal metadata chunk in SM90
Index Bin HEX
0, 1 0b0100 4
1, 2 0b1001 9
2, 3 0b1110 E
0, 2 0b1000 8
1, 3 0b1101 D
0, 3 0b1100 C
2, 1 0b0110 6 (Not used)
-----------------------------------
TF32
0 0b0100 4
1 0b1110 E
*/
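// Worked example (non-tf32): non-zeros at in-chunk indices 1 and 3 encode as 0b01 in the
// low two bits and 0b11 in the high two bits of the nibble, i.e. 0b1101 = 0xD, matching
// the (1, 3) row in the table above.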
if (effective_elems <= 0) {
return;
}
// Initialize. 0 is the initial value used by this function, while 0x44 is the hardware's initial value.
meta_two_chunk = 0;
for (int chunk_idx = 0; chunk_idx < 2; ++chunk_idx) {
// Stop if there are no effective elements left for this chunk (only one chunk in this two-chunk group)
if ( effective_elems <= chunk_idx * ElemsARawPerElementAMmaRaw * LogicalSubChunk ) {
break;
}
/// init result;
int non_zero_cnt = 0;
int32_t nnz_chunk_idx[PhysicalSubChunk] = { 0 };
ElementA Ac_chunk[PhysicalSubChunk][ElemsARawPerElementAMmaRaw] = { ElementA{0} };
for (int subchunk_idx = 0; subchunk_idx < LogicalSubChunk; ++subchunk_idx) {
bool is_nz = true;
ElementA subchunk_elems[ElemsARawPerElementAMmaRaw] = { ElementA{0} };
/// Check if subchunk is non-zero
for(int elem_idx = 0; elem_idx < ElemsARawPerElementAMmaRaw; elem_idx++) {
int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx;
subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0);
if (subchunk_elems[elem_idx] != ElementA(0)) {
if (non_zero_cnt >= PhysicalSubChunk) {
#ifdef __CUDA_ARCH__
asm volatile ("brkpt;\n" ::);
#else
throw std::runtime_error("Found extra non-zero elements in a chunk!\n");
#endif
}
is_nz = false;
}
}
/// There is a non-zero element in this subchunk
if(!is_nz) {
nnz_chunk_idx[non_zero_cnt] = subchunk_idx;
memcpy(Ac_chunk[non_zero_cnt], subchunk_elems, sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
non_zero_cnt++;
}
}
/*
Special case
nnz == 1 for non-tf32 types (with a sub-case when the lone non-zero sits at in-chunk index 3)
*/
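// When only one subchunk is non-zero, it is paired with an explicit zero subchunk so that the
// resulting index pair is always one of the legal encodings listed above:
// (0,3) -> 0xC, (1,3) -> 0xD, (2,3) -> 0xE.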
ElementA elementA_zeros[ElemsARawPerElementAMmaRaw] = { ElementA{0} };
if constexpr (sizeof_bits_v<ElementA> < 32) {
if (non_zero_cnt == 1 && nnz_chunk_idx[0] == 3) {
memcpy(Ac_chunk[1], Ac_chunk[0], sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
memcpy(Ac_chunk[0], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
nnz_chunk_idx[1] = 3;
nnz_chunk_idx[0] = 0;
}
else if (non_zero_cnt == 1) {
memcpy(Ac_chunk[1], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
nnz_chunk_idx[1] = 3;
}
}
/// Setup metadata
uint8_t meta_chunk = 0;
for (int i = 0; i < PhysicalSubChunk; i++) {
meta_chunk = static_cast<uint8_t>(meta_chunk | (encode_in_chunk_idx_legacy<ElementA>(nnz_chunk_idx[i]) << (i * ElementEBitsPerElementAMma)));
for(int j = 0; j < ElemsARawPerElementAMmaRaw; j++) {
tensorAc(chunk_idx * PhysicalElemsAPerChunk + i * ElemsARawPerElementAMmaRaw + j) = Ac_chunk[i][j];
}
}
meta_two_chunk = uint8_t(meta_two_chunk | (meta_chunk << (chunk_idx * _4{})));
}
}
}
template<
class ProblemShape_,
class ElementA_,
class LayoutATag_,
class SparseConfig_
>
class SM90StructuredSparseCompressorLegacy {
public:
using SparseConfig = SparseConfig_;
using ProblemShape = ProblemShape_;
// * EltA
using ElementA = ElementA_;
using ElementAUint = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
static constexpr bool IsRuntimeDataTypeA = cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float8_t> ||
cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float6_t> ||
cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float4_t>;
using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA,
cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>,
ElementA>;
using ElementAMma = typename SparseConfig::ElementAMma;
using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw;
using ElementASparsity = typename SparseConfig::ElementASparsity;
using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity;
using LayoutATag = LayoutATag_;
using LayoutA = LayoutATag;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutATag>;
// * EltE
using ElementEMma = typename SparseConfig::ElementEMma;
using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw;
using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity;
// * AtomE
using TensorEAtom = typename SparseConfig::TensorEAtom;
using TensorEAtomK = typename SparseConfig::TensorEAtomK;
using TensorEAtomM = typename SparseConfig::TensorEAtomM;
static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
// * Alignment
static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{};
static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{};
// Required by `device_kernel`
static constexpr int MaxThreadsPerBlock = 1;
static constexpr int MinBlocksPerMultiprocessor = 1;
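// Note: this legacy compressor is a single-threaded reference implementation;
// get_grid_shape() and get_block_shape() below both return dim3(1, 1, 1).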
using ArchTag = arch::Sm90;
struct SharedStorage {
/* empty, no smem needed */
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
struct TransformArguments {
ArrayElementA const* ptr_A{nullptr};
StrideA dA{};
ArrayElementA* ptr_ACompress{nullptr};
ElementEMmaRaw* ptr_E{nullptr};
};
using TransformParams = TransformArguments;
struct Arguments {
ProblemShape problem_shape{};
TransformArguments transform{};
KernelHardwareInfo hw_info{};
};
struct Params {
ProblemShape problem_shape{};
TransformParams transform{};
KernelHardwareInfo hw_info{};
void* workspace = nullptr;
};
static Params
to_underlying_arguments(Arguments & args, void* workspace) {
return Params{{args.problem_shape},
{args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E},
{args.hw_info},
workspace};
}
static Status
can_implement(Arguments const& args) {
auto [M, N, K, L] = args.problem_shape;
if (K % LogicalElemsAPerChunk != 0) {
CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK is not a multiple of the logical chunk size\n");
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
static size_t
get_workspace_size(Arguments const& args) {
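// The workspace backs the raw (pre-reorder) metadata written by do_compress_device_host():
// the E-aligned M x K extent, divided by the metadata sparsity, for all L batches.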
auto problem = args.problem_shape;
const int m = cute::size<0>(problem);
const int k = cute::size<2>(problem);
const int l = cute::size<3>(problem);
const int metadata_k = round_up(k, TensorEAlignmentK);
const int metadata_m = round_up(m, TensorEAlignmentM);
const int metadata_bytes = metadata_m * metadata_k / ElementEMmaSparsity{} * l;
return metadata_bytes;
}
static Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
cudaError_t cuda_error;
auto workspace_size = get_workspace_size(args);
if (workspace_size == 0) {
return Status::kSuccess;
} else if (workspace == nullptr) {
return Status::kErrorInternal;
}
cudaPointerAttributes attri;
cuda_error = cudaPointerGetAttributes(&attri, workspace);
if (cuda_error != cudaSuccess) {
return Status::kErrorInternal;
}
if ( attri.type == cudaMemoryTypeDevice ) {
#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER
CUTLASS_ASSERT(cuda_adapter);
if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast<uint8_t>(0), workspace_size, stream)) {
return Status::kErrorInternal;
}
#else
cudaMemsetAsync(workspace, 0, workspace_size, stream);
cuda_error = cudaGetLastError();
if (cuda_error != cudaSuccess) {
return Status::kErrorInternal;
}
#endif
} else {
memset(workspace, 0, workspace_size);
}
return Status::kSuccess;
}
static dim3
get_grid_shape(Params const& params) {
return dim3(1, 1, 1);
}
static dim3
get_block_shape() {
return dim3(1, 1, 1);
}
CUTE_HOST_DEVICE
void
operator()(Params params, char* smem_buf = nullptr) {
run(params, smem_buf);
}
CUTE_HOST_DEVICE
static void
run(Params params, char* smem_buf = nullptr) {
do_compress_device_host(params);
}
private:
CUTE_HOST_DEVICE
static void
do_compress_device_host(Params params) {
auto [m, n, k, l] = params.problem_shape;
auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform;
auto workspace = params.workspace;
const int aligned_k = (k + TensorAAlignmentK - 1) / TensorAAlignmentK * TensorAAlignmentK;
const int aligned_m = (m + TensorAAlignmentM - 1) / TensorAAlignmentM * TensorAAlignmentM;
const int metadata_k = (k + TensorEAlignmentK - 1) / TensorEAlignmentK * TensorEAlignmentK;
const int metadata_m = (m + TensorEAlignmentM - 1) / TensorEAlignmentM * TensorEAlignmentM;
const int k_compressed = aligned_k / ElementASparsity{};
// Convert to CuTe tensors. Avoid sparse_ptr here, since it would only complicate the indexing.
cute::Tensor tensorA = make_tensor(recast_ptr<ElementAUint>(ptr_A), make_layout(make_shape(m, k, l), dA));
cute::Tensor tensorAc = make_tensor(recast_ptr<ElementAUint>(ptr_ACompress),
make_shape(aligned_m, k_compressed, l),
make_cute_packed_stride(StrideA{}, cute::make_shape(aligned_m, k_compressed, l)));
cute::Tensor tensorE_raw_compress_logical = make_tensor(recast_ptr<sparse_elem<ElementEMmaSparsity{},ElementEMmaRaw>>(workspace),
make_shape(metadata_m, make_shape(TensorEAtomK{}, metadata_k / TensorEAtomK{}), l),
make_stride(TensorEAtomK{}, make_stride(_1{}, metadata_m*TensorEAtomK{}), metadata_m*metadata_k));
cute::Tensor tensorE_raw_compress = recast<uint8_t>(tensorE_raw_compress_logical);
// The following vars are all logical.
int atom_m = size<0>(TensorEAtom{});
int atom_k = size<1>(TensorEAtom{});
int tiled_m = metadata_m / atom_m;
int tiled_ke = metadata_k / atom_k;
// Col major when viewing atoms
int stride_tile_m = cosize(TensorEAtom{});
int stride_tile_ke = atom_k * metadata_m;
// Logical metadata tensor
cute::Tensor tensorE_logical = make_tensor(recast_ptr<sparse_elem<ElementEMmaSparsity{},ElementEMmaRaw>>(ptr_E),
make_layout(make_shape(append(shape<0>(TensorEAtom{}), tiled_m),
append(shape<1>(TensorEAtom{}), tiled_ke),
shape<2>(tensorE_raw_compress_logical)),
make_stride(append(stride<0>(TensorEAtom{}), stride_tile_m),
append(stride<1>(TensorEAtom{}), stride_tile_ke),
stride<2>(tensorE_raw_compress_logical))));
// Physical metadata tensor
cute::Tensor tensorE = recast<uint8_t>(tensorE_logical);
// void do_init()
cute::clear(tensorAc);
cute::clear(tensorE_raw_compress);
// void do_raw_compress()
using TileStepA = Int<LogicalElemsAPerChunk * 2>;
using TileStepAc = Int<TileStepA{} / 2>;
cute::Tensor tensorATiled = logical_divide(tensorA, make_shape(_, TileStepA{}, _));
cute::Tensor tensorAcTiled = logical_divide(tensorAc, make_shape(_, TileStepAc{}, _));
for (int batch_idx = 0; batch_idx < l; batch_idx++) {
for (int m_idx = 0; m_idx < m; m_idx++) {
for (int tiler_k_idx = 0; tiler_k_idx < size<1,1>(tensorATiled); tiler_k_idx++) {
int effective_elems = cute::min(TileStepA{}, k - (tiler_k_idx * TileStepA{}));
detail::compress_two_chunks_legacy<SparseConfig>(tensorATiled(m_idx, make_coord(_, tiler_k_idx), batch_idx),
tensorAcTiled(m_idx, make_coord(_, tiler_k_idx), batch_idx),
tensorE_raw_compress(m_idx, tiler_k_idx, batch_idx),
effective_elems);
}
}
}
// void do_reorder()
// Fast path when we don't permute.
if constexpr (sizeof_bits_v<ElementAUint> <= 8) {
memcpy(tensorE.data(), tensorE_raw_compress.data(), tensorE.size());
}
else {
cute::copy(tensorE_raw_compress, tensorE);
}
#if 0
print("--> TensorA\n");
auto tensorA_eltA = cute::recast<ElementA>(tensorA);
cute::print_tensor(tensorA_eltA); printf("\n\n");
print("--> REF TensorAC\n");
auto tensorAc_eltA = cute::recast<ElementA>(tensorAc);
cute::print_tensor(tensorAc_eltA); printf("\n\n");
print("--> REF TensorE\n");
cute::print_tensor(tensorE); printf("\n\n");
#endif
}
};
} // namespace kernel
} // namespace transform
} // namespace cutlass


@@ -0,0 +1,876 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* @brief Test for structured sparse gemm compressor device kernel
*/
#pragma once
#include <cuda_runtime_api.h> // cudaGetLastError
#include <cstdint> // uint64_t
#include <cstdio> // printf
#include <cstdlib> // malloc
#include <iostream> // std::cout
#include <vector>
#include <array>
#include "cute/layout.hpp" // cute::make_shape
#include "cute/util/type_traits.hpp" // cute::is_same_v
#include "cutlass/coord.h" // cutlass::make_Coord
#include "cutlass/cutlass.h" // cutlass::Status
#include "cutlass/kernel_hardware_info.hpp" // cutlass::KernelHardwareInfo
#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory
#include "cutlass/numeric_types.h" // cutlass::sizeof_bits, cutlass::float_
#include "cutlass/tensor_view.h" // cutlass::TensorView
#include "cutlass/transform/device/transform_universal_adapter.hpp" // cutlass::transform::device::TransformUniversalAdapter
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // cutlass::transform::kernel::StructuredSparseCompressorUtility
#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation
#include "cutlass/util/distribution.h" // cutlass::Distribution
#include "cutlass/util/host_tensor.h" // cutlass::HostTensor
#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride
#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals
#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill
#include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor
#include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE
#define CUDA_CHECK_FALSE(cuda_error) \
{ \
if (cuda_error != cudaSuccess) { \
printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \
return false; \
} \
}
#define CUDA_CHECK(cuda_error) \
{ \
if (cuda_error != cudaSuccess) { \
printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \
return; \
} \
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// * Test Bed
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace test
{
namespace transform
{
namespace device
{
// Helper Functions
template <typename Element, typename Layout>
bool
initialize_tensor(cutlass::TensorView<Element, Layout> view, cutlass::Distribution::Kind dist_kind, uint64_t seed)
{
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 1;
scope_min = -1;
} else {
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else {
EXPECT_TRUE(false) << "Not implemented";
return false;
}
return true;
}
// Testbed
template <typename Compressor_>
struct TestbedSparseGemmCompressor {
public:
using Compressor = Compressor_;
using CompressorKernel = typename Compressor::TransformKernel;
using ElementA = typename CompressorKernel::ElementA;
using LayoutATag = typename CompressorKernel::LayoutATag;
using StrideA = typename CompressorKernel::StrideA;
using ArrayElementA = ElementA;
using ElementE = typename CompressorKernel::ElementEMmaRaw;
using LayoutETag = cutlass::layout::RowMajor; // The layout major does not matter here; it is only used to allocate the tensor
using SparseConfig = typename CompressorKernel::SparseConfig;
using ProblemShapeType = typename CompressorKernel::ProblemShape;
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShapeType,
ElementA,
LayoutATag,
SparseConfig>;
using CompressorKernelHost = cutlass::transform::kernel::SM90StructuredSparseCompressorLegacy<
ProblemShapeType,
ElementA,
LayoutATag,
SparseConfig>;
using CompressorHost = cutlass::transform::device::TransformUniversalAdapter<CompressorKernelHost>;
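// The legacy compressor kernel is host-callable; run_host_ref() below invokes it directly on
// the CPU to produce the reference outputs (tensor_A_Comp_ref and tensor_E_ref).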
static constexpr auto LogicalElemsAPerChunk = CompressorKernel::LogicalElemsAPerChunk;
static constexpr auto PhysicalElemsAPerChunk = CompressorKernel::PhysicalElemsAPerChunk;
struct Data {
// Data Storage
cutlass::HostTensor<ArrayElementA, LayoutATag> tensor_A;
cutlass::HostTensor<ArrayElementA, LayoutATag> tensor_A_Comp;
cutlass::HostTensor<ElementE, LayoutETag> tensor_E;
cutlass::HostTensor<ArrayElementA, LayoutATag> tensor_A_Comp_ref;
cutlass::HostTensor<ElementE, LayoutETag> tensor_E_ref;
};
struct CudaRAII {
cudaStream_t stream;
cudaEvent_t start;
cudaEvent_t stop;
CudaRAII(){
CUDA_CHECK(cudaStreamCreate( &stream ));
CUDA_CHECK(cudaEventCreate( &start ));
CUDA_CHECK(cudaEventCreate( &stop ));
};
CudaRAII(const CudaRAII&) = delete;
CudaRAII& operator=(const CudaRAII&) = delete;
CudaRAII(CudaRAII&&) = delete;
CudaRAII& operator=(CudaRAII&&) = delete;
~CudaRAII(){
CUDA_CHECK(cudaStreamDestroy( stream ));
CUDA_CHECK(cudaEventDestroy( start ));
CUDA_CHECK(cudaEventDestroy( stop ));
}
};
public:
TestbedSparseGemmCompressor(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_Comp_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 7)
: init_A(init_A_)
, init_E(init_E_)
, init_A_Comp(init_A_Comp_)
, seed(seed_)
{
}
bool valid_test(ProblemShapeType problem_shape_MNKL)
{
const int GemmK = cute::size<2>(problem_shape_MNKL);
if ( GemmK % LogicalElemsAPerChunk != 0 ) {
printf("GemmK needs to be multiplier of LogicalElemsAPerChunk\n");
return false;
}
return true;
}
bool initialize(ProblemShapeType problem_shape_MNKL, Data& datas)
{
CUDA_CHECK_FALSE(cudaGetLastError());
// In unit of ElementARaw
const int GemmM = cute::size<0>(problem_shape_MNKL);
const int GemmN = cute::size<1>(problem_shape_MNKL);
const int GemmK = cute::size<2>(problem_shape_MNKL);
const int GemmL = cute::size<3>(problem_shape_MNKL);
// Compressor utility to get allocated data size
auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL));
CompressorUtility compressor_utility(problem_shape_MNKL, stride_a);
// TensorA
// In unit of ElementARaw, after alignment requirement
// M-dim: no alignment requirement
// K-dim: multiplier of chunk size
// TensorA Compressed
// In unit of ElementARaw, after alignment requirement
// M-dim: TMA alignment
// K-dim: TMA alignment
const int GemmMAlignedAC = compressor_utility.get_tensorA_m_physical();
const int GemmKAlignedAC = compressor_utility.get_tensorA_k_physical();
// TensorE
// In unit of ElementE (uint8_t), after alignment requirement
// M-dim: TensorEAtom_M alignment
// K-dim: TensorEAtom_K alignment
const int GemmMAlignedE = compressor_utility.get_metadata_m_physical();
const int GemmKAlignedE = compressor_utility.get_metadata_k_physical();
auto a_coord = cutlass::make_Coord(GemmM * GemmL, GemmK);
auto e_coord = cutlass::make_Coord(GemmMAlignedE * GemmL, GemmKAlignedE);
auto a_comp_coord = cutlass::make_Coord(GemmMAlignedAC * GemmL, GemmKAlignedAC);
typename LayoutATag::Stride stride_factor_A;
typename LayoutETag::Stride stride_factor_E;
datas.tensor_A.resize(a_coord,
cutlass::layout::Affine2Layout_Factory<LayoutATag>::layout_factory(a_coord, stride_factor_A));
datas.tensor_A_Comp.resize(a_comp_coord,
cutlass::layout::Affine2Layout_Factory<LayoutATag>::layout_factory(a_comp_coord, stride_factor_A));
datas.tensor_A_Comp_ref.resize(a_comp_coord,
cutlass::layout::Affine2Layout_Factory<LayoutATag>::layout_factory(a_comp_coord, stride_factor_A),
false);
datas.tensor_E.resize(e_coord,
cutlass::layout::Affine2Layout_Factory<LayoutETag>::layout_factory(e_coord, stride_factor_E));
datas.tensor_E_ref.resize(e_coord,
cutlass::layout::Affine2Layout_Factory<LayoutETag>::layout_factory(e_coord, stride_factor_E),
false);
EXPECT_TRUE(initialize_tensor(datas.tensor_A.host_view(), init_A, seed + 1));
EXPECT_TRUE(initialize_tensor(datas.tensor_E.host_view(), init_E, seed + 2));
EXPECT_TRUE(initialize_tensor(datas.tensor_E_ref.host_view(), init_E, seed + 3));
EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp.host_view(), init_A_Comp, seed + 4));
EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp_ref.host_view(), init_A_Comp, seed + 5));
compressor_utility.structure_sparse_zero_mask_fill(datas.tensor_A.host_data(), seed + 6);
// Check for prior device errors
CUDA_CHECK_FALSE(cudaGetLastError());
datas.tensor_A.sync_device();
datas.tensor_A_Comp.sync_device();
datas.tensor_E.sync_device();
// Check for prior device errors
CUDA_CHECK_FALSE(cudaGetLastError());
return true;
}
bool run_device(ProblemShapeType problem_shape_MNKL, Data& datas, float* time = nullptr)
{
CudaRAII cuda_raii;
const int GemmM = cute::size<0>(problem_shape_MNKL);
const int GemmN = cute::size<1>(problem_shape_MNKL);
const int GemmK = cute::size<2>(problem_shape_MNKL);
const int GemmL = cute::size<3>(problem_shape_MNKL);
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL));
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{
{GemmM, GemmN, GemmK, GemmL},
{datas.tensor_A.device_data(),
stride_a,
datas.tensor_A_Comp.device_data(),
datas.tensor_E.device_data()},
{hw_info}
};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status {cutlass::Status::kSuccess };
status = compressor_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
CUDA_CHECK_FALSE(cudaGetLastError());
}
status = compressor_op.initialize(arguments, workspace.get(), cuda_raii.stream);
if (status != cutlass::Status::kSuccess) {
CUDA_CHECK_FALSE(cudaGetLastError());
}
CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream));
CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.start, cuda_raii.stream));
status = compressor_op.run(cuda_raii.stream);
if (status != cutlass::Status::kSuccess) {
CUDA_CHECK_FALSE(cudaGetLastError());
}
CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.stop, cuda_raii.stream));
CUDA_CHECK_FALSE(cudaEventSynchronize(cuda_raii.stop));
CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream));
if ( time != nullptr ){
CUDA_CHECK_FALSE(cudaEventElapsedTime(time, cuda_raii.start, cuda_raii.stop));
}
datas.tensor_A_Comp.sync_host();
datas.tensor_E.sync_host();
#if 0
{
printf("\n--> DEVICE OUTPUT\n");
printf("datas.tensor_A\n");
std::cout << datas.tensor_A.host_view() << std::endl << std::endl;
printf("datas.tensor_A_Comp\n");
std::cout << datas.tensor_A_Comp.host_view() << std::endl << std::endl;
printf("datas.tensor_E\n");
std::cout << datas.tensor_E.host_view() << std::endl << std::endl;
}
#endif
return true;
}
bool run_host_ref(ProblemShapeType problem_shape_MNKL, Data& datas)
{
const int GemmM = cute::size<0>(problem_shape_MNKL);
const int GemmN = cute::size<1>(problem_shape_MNKL);
const int GemmK = cute::size<2>(problem_shape_MNKL);
const int GemmL = cute::size<3>(problem_shape_MNKL);
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL));
typename CompressorKernelHost::Arguments arguments{
{GemmM, GemmN, GemmK, GemmL},
{datas.tensor_A.host_data(),
stride_a,
datas.tensor_A_Comp_ref.host_data(),
datas.tensor_E_ref.host_data()},
{}};
const auto can_imp = CompressorKernelHost::can_implement(arguments);
if (can_imp != cutlass::Status::kSuccess) {
printf("can_implement() check failed\n");
return false;
}
// Relies on std::vector for RAII
auto workspace_size =
static_cast<std::vector<uint8_t>::size_type>(CompressorKernelHost::get_workspace_size(arguments));
std::vector<uint8_t> workspace_vector(workspace_size);
auto workspace = static_cast<void*>(workspace_vector.data());
cutlass::Status status = CompressorKernelHost::initialize_workspace(arguments, workspace);
if (status != cutlass::Status::kSuccess) {
printf("initialize_workspace() failed\n");
return false;
}
auto params = CompressorKernelHost::to_underlying_arguments(arguments, workspace);
CompressorKernelHost::run(params);
return true;
}
bool compare_reference(Data& datas)
{
bool check_tensor_a_compressed =
cutlass::reference::host::TensorEquals(datas.tensor_A_Comp_ref.host_view(), datas.tensor_A_Comp.host_view());
if (!check_tensor_a_compressed) {
printf("A-Compressed Mismatch\n");
}
bool check_tensor_e = cutlass::reference::host::TensorEquals(datas.tensor_E_ref.host_view(), datas.tensor_E.host_view());
if (!check_tensor_e) {
printf("E Mismatch\n");
}
return check_tensor_a_compressed && check_tensor_e;
}
bool run_auto_small()
{
return run_auto(true);
}
bool run_auto(bool run_small = false)
{
constexpr auto TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
constexpr auto TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
constexpr int GemmN = 1;
using ProblemType = std::array<int, 4>;
std::vector<ProblemType> problems;
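// Each ProblemType entry is {GemmM, GemmN, GemmK, GemmL}, matching the order passed to Compressor::Arguments.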
const std::vector<ProblemType> problems_multiplier_of_tensor_e_atom = {
// * Regular Cases (multiples of TensorEAlignment)
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 1},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 1},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 1},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 2},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 2},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 2},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3},
{TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 3},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3},
{TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 3},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3},
{TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 3},
};
const std::vector<ProblemType> problems_multiplier_of_tensor_e_atom_large = {
// * Large Case (multiple of TensorEAlignment)
{TensorEAlignmentM * 10, GemmN, TensorEAlignmentK * 13, 1},
// {TensorEAlignmentM * 11, GemmN, TensorEAlignmentK * 14, 2},
// {TensorEAlignmentM * 12, GemmN, TensorEAlignmentK * 15, 3},
};
const std::vector<ProblemType> problems_multiplier_of_twochunk {
// * Corner Cases
{4, GemmN, LogicalElemsAPerChunk * 2, 1},
{4, GemmN, LogicalElemsAPerChunk * 4, 1},
{4, GemmN, LogicalElemsAPerChunk * 6, 1},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},
{4, GemmN, LogicalElemsAPerChunk * 2, 2},
{4, GemmN, LogicalElemsAPerChunk * 4, 2},
{4, GemmN, LogicalElemsAPerChunk * 6, 2},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},
{4, GemmN, LogicalElemsAPerChunk * 2, 3},
{4, GemmN, LogicalElemsAPerChunk * 4, 3},
{4, GemmN, LogicalElemsAPerChunk * 6, 3},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},
{32 + 4, GemmN, LogicalElemsAPerChunk * 2, 1},
{32 + 4, GemmN, LogicalElemsAPerChunk * 4, 1},
{32 + 4, GemmN, LogicalElemsAPerChunk * 6, 1},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},
{32 + 4, GemmN, LogicalElemsAPerChunk * 2, 2},
{32 + 4, GemmN, LogicalElemsAPerChunk * 4, 2},
{32 + 4, GemmN, LogicalElemsAPerChunk * 6, 2},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},
{32 + 4, GemmN, LogicalElemsAPerChunk * 2, 3},
{32 + 4, GemmN, LogicalElemsAPerChunk * 4, 3},
{32 + 4, GemmN, LogicalElemsAPerChunk * 6, 3},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 1},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 1},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 2},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 2},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 3},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 3},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 1},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 1},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 2},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 2},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 3},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 3},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},
};
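// Problem sizes whose K extent is an odd multiple of a single logical chunk
// (optionally offset by whole TensorE K alignments), swept over several M
// offsets around the TensorE M alignment and over batch counts 1-3.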
const std::vector<ProblemType> problems_multiplier_of_onechunk {
{4, GemmN, LogicalElemsAPerChunk * 1, 1},
{4, GemmN, LogicalElemsAPerChunk * 3, 1},
{4, GemmN, LogicalElemsAPerChunk * 5, 1},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},
{4, GemmN, LogicalElemsAPerChunk * 1, 2},
{4, GemmN, LogicalElemsAPerChunk * 3, 2},
{4, GemmN, LogicalElemsAPerChunk * 5, 2},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},
{4, GemmN, LogicalElemsAPerChunk * 1, 3},
{4, GemmN, LogicalElemsAPerChunk * 3, 3},
{4, GemmN, LogicalElemsAPerChunk * 5, 3},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
{4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
{4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},
{32 + 4, GemmN, LogicalElemsAPerChunk * 1, 1},
{32 + 4, GemmN, LogicalElemsAPerChunk * 3, 1},
{32 + 4, GemmN, LogicalElemsAPerChunk * 5, 1},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},
{32 + 4, GemmN, LogicalElemsAPerChunk * 1, 2},
{32 + 4, GemmN, LogicalElemsAPerChunk * 3, 2},
{32 + 4, GemmN, LogicalElemsAPerChunk * 5, 2},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},
{32 + 4, GemmN, LogicalElemsAPerChunk * 1, 3},
{32 + 4, GemmN, LogicalElemsAPerChunk * 3, 3},
{32 + 4, GemmN, LogicalElemsAPerChunk * 5, 3},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
{32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
{32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 1},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 1},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 2},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 2},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 3},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 3},
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
{TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 1},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 1},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 2},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 2},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 3},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 3},
{TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},
};
// Small run: only cover the cases whose sizes are multiples of the chunk size
if (run_small) {
problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end());
}
// Full run: cover all corner cases
else {
problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom_large.begin(), problems_multiplier_of_tensor_e_atom_large.end());
problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end());
problems.insert(problems.end(), problems_multiplier_of_twochunk.begin(), problems_multiplier_of_twochunk.end());
problems.insert(problems.end(), problems_multiplier_of_onechunk.begin(), problems_multiplier_of_onechunk.end());
}
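// Run every selected problem, print a per-problem PASS/FAIL line, and stop at the first failure.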
for (const auto& problem_shape_MNKL : problems) {
const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL;
bool passed = run({GemmM, GemmN, GemmK, GemmL});
printf("run() (%.4d,%.4d,%.4d,%.4d) %s\n", GemmM, GemmN, GemmK, GemmL, passed ? "PASS" : "FAIL");
CUTLASS_TRACE_HOST("run() " << GemmM << " " << GemmN << " " << GemmK << " " << GemmL << passed ? " PASS" : " FAIL");
if (not passed) {
return false;
}
}
return true;
}
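// Run a single problem shape: validate the shape, initialize data, run the host
// reference and the device compressor, then compare the two results.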
bool run(ProblemShapeType problem_shape_MNKL)
{
// Check if valid test
if (not valid_test(problem_shape_MNKL)) {
CUTLASS_TRACE_HOST("valid_test() fail\n");
return false;
}
// Data Storage
Data datas;
// Initialize Data
if (not initialize(problem_shape_MNKL, datas)) {
CUTLASS_TRACE_HOST("initialize() fail\n");
return false;
}
// Run Compressor (Host Ref)
if (not run_host_ref(problem_shape_MNKL, datas)) {
CUTLASS_TRACE_HOST("run_host() fail\n");
return false;
}
// Run Compressor (Device)
if (not run_device(problem_shape_MNKL, datas)) {
CUTLASS_TRACE_HOST("run_device() fail\n");
return false;
}
// Verify
if (not compare_reference(datas)) {
CUTLASS_TRACE_HOST("compare_reference() DEVICE <-> LEGACY HOST fail\n");
printf("compare_reference() DEVICE <-> LEGACY HOST fail\n");
return false;
}
// else {
// printf("DEVICE <-> HOST PASS\n");
// }
return true;
}
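// Benchmark the device compressor for one problem shape. Each iteration works on
// freshly initialized data to mimic a cold cache; warm-up iterations run first and
// are excluded from the reported mean kernel time.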
bool benchmark(ProblemShapeType problem_shape_MNKL) {
const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL;
printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) START\n", GemmM, GemmN, GemmK, GemmL);
// Check if valid test
if (not valid_test(problem_shape_MNKL)) {
CUTLASS_TRACE_HOST("valid_test() fail\n");
return false;
}
// 5 warm-up iterations and 10 timing iterations
constexpr int num_warmup = 5;
constexpr int num_iter = 10;
// Duplicate data to mimic cold cache
Data data[num_warmup + num_iter];
double total_time_milliseconds{0.0};
for (int i = 0; i < num_warmup + num_iter; ++i ) {
printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) ITER %d\n", GemmM, GemmN, GemmK, GemmL, i );
auto& datum_i = data[i];
// Initialize Data
if (not initialize(problem_shape_MNKL, datum_i)) {
CUTLASS_TRACE_HOST("initialize() fail\n");
return false;
}
// Run Compressor (Device)
double time_i_milliseconds{0.0};
if (not run_device(problem_shape_MNKL, datum_i, &time_i_milliseconds)) {
CUTLASS_TRACE_HOST("run_device() fail\n");
return false;
}
if ( i >= num_warmup ) {
total_time_milliseconds += time_i_milliseconds;
}
}
const double mean_time_milliseconds = total_time_milliseconds / num_iter;
printf("Mean time (ms): %.5f\n", mean_time_milliseconds);
return true;
}
public:
// Data initialization settings
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_A_Comp;
cutlass::Distribution::Kind init_E;
uint64_t seed;
};
} // namespace device
} // namespace transform
} // namespace test