CUTLASS 3.6.0 (#1850)

* v3.6
* update changelog
* update readme
* fix typo
* fixing typos
* hopper gemm with weight prefetch

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
58  test/unit/transform/device/CMakeLists.txt (new file)

@@ -0,0 +1,58 @@
# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#
# Compress Kernel
#

add_custom_target(
  cutlass_test_unit_sm90_structured_sparse_gemm_compressor
  DEPENDS
  cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f32
  cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f16
  cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f8
)

cutlass_test_unit_add_executable(
  cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f32

  sm90_sparse_gemm_compressor_f32.cu
)

cutlass_test_unit_add_executable(
  cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f16

  sm90_sparse_gemm_compressor_f16.cu
)

cutlass_test_unit_add_executable(
  cutlass_test_unit_sm90_structured_sparse_gemm_compressor_f8

  sm90_sparse_gemm_compressor_f8.cu
)

95  test/unit/transform/device/sm90_sparse_gemm_compressor_f16.cu (new file)

@@ -0,0 +1,95 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major
|
||||
#include "cutlass/arch/config.h" // CUTLASS_ARCH_MMA_SM90_SUPPORTED
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl" // gmma_ss_tag_to_major_A
|
||||
#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" // Sm90GemmSparseConfig
|
||||
#include "testbed_sparse_gemm_compressor.hpp" // TestbedSparseGemmCompressor
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// * Test Plan
|
||||
// ElementA : fp16
|
||||
// LayoutA : row / col
|
||||
// Gemm : 1x 2x 3x multiplier of alignment requirement. corner case that smaller than alignment requirement
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f16_t)
|
||||
{
|
||||
// Test Settings
|
||||
using ElementA = cutlass::half_t;
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
|
||||
// Deduct From Test Setting
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
|
||||
using ElementAMma = cute::sparse_elem<2, ElementA>;
|
||||
using ElementEMma = cute::sparse_elem<8, uint8_t>;
|
||||
|
||||
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<32>>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::
|
||||
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
// Test Bed
|
||||
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
|
||||
EXPECT_TRUE(testbed.run_auto());
|
||||
}
|
||||
|
||||
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f16_n)
|
||||
{
|
||||
// Test Settings
|
||||
using ElementA = cutlass::bfloat16_t;
|
||||
using LayoutATag = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Deduct From Test Setting
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
|
||||
using ElementAMma = cute::sparse_elem<2, ElementA>;
|
||||
using ElementEMma = cute::sparse_elem<8, uint8_t>;
|
||||
|
||||
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<64>>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::
|
||||
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
// Test Bed
|
||||
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
|
||||
EXPECT_TRUE(testbed.run_auto());
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
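For orientation, this is roughly how one of these compressor types is driven outside the GoogleTest harness. It is a minimal sketch following TestbedSparseGemmCompressor::run_device from the testbed later in this commit; buffer setup and problem extents are left to the caller, and the helper name run_compressor is illustrative, not part of the commit.

#include <cuda_runtime_api.h>            // cudaStream_t
#include "cutlass/cutlass.h"             // cutlass::Status
#include "cutlass/util/device_memory.h"  // cutlass::device_memory::allocation

// Sketch: drive a TransformUniversalAdapter-wrapped compressor on a stream.
template <class Compressor>
cutlass::Status run_compressor(typename Compressor::Arguments const& arguments,
                               cudaStream_t stream = nullptr) {
  Compressor compressor_op;

  // Reject unsupported problems early (e.g., GemmK not a multiple of the chunk size).
  cutlass::Status status = compressor_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // The kernel stages raw metadata in a scratch workspace before reordering it.
  size_t workspace_size = Compressor::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  status = compressor_op.initialize(arguments, workspace.get(), stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // On success, ptr_ACompress holds the 2:4-compressed A and ptr_E the reordered metadata.
  return compressor_op.run(stream);
}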

95  test/unit/transform/device/sm90_sparse_gemm_compressor_f32.cu (new file)

@@ -0,0 +1,95 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major
|
||||
#include "cutlass/arch/config.h" // CUTLASS_ARCH_MMA_SM90_SUPPORTED
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl" // gmma_ss_tag_to_major_A
|
||||
#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" // Sm90GemmSparseConfig
|
||||
#include "testbed_sparse_gemm_compressor.hpp" // TestbedSparseGemmCompressor
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// * Test Plan
|
||||
// ElementA : fp32
|
||||
// LayoutA : row / col
|
||||
// Gemm : 1x 2x 3x multiplier of alignment requirement. corner case that smaller than alignment requirement
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f32_t)
|
||||
{
|
||||
// Test Settings
|
||||
using ElementA = float;
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
|
||||
// Deduct From Test Setting
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
|
||||
using ElementAMma = cute::sparse_elem<2, ElementA>;
|
||||
using ElementEMma = cute::sparse_elem<4, uint8_t>;
|
||||
|
||||
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<16>>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::
|
||||
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
// Test Bed
|
||||
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
|
||||
EXPECT_TRUE(testbed.run_auto());
|
||||
}
|
||||
|
||||
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f32_n)
|
||||
{
|
||||
// Test Settings
|
||||
using ElementA = cutlass::tfloat32_t;
|
||||
using LayoutATag = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Deduct From Test Setting
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
|
||||
using ElementAMma = cute::sparse_elem<2, ElementA>;
|
||||
using ElementEMma = cute::sparse_elem<4, uint8_t>;
|
||||
|
||||
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<32>>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::
|
||||
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
// Test Bed
|
||||
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
|
||||
EXPECT_TRUE(testbed.run_auto());
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||

95  test/unit/transform/device/sm90_sparse_gemm_compressor_f8.cu (new file)

@@ -0,0 +1,95 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major
|
||||
#include "cutlass/arch/config.h" // CUTLASS_ARCH_MMA_SM90_SUPPORTED
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl" // gmma_ss_tag_to_major_A
|
||||
#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" // Sm90GemmSparseConfig
|
||||
#include "testbed_sparse_gemm_compressor.hpp" // TestbedSparseGemmCompressor
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// * Test Plan
|
||||
// ElementA : fp8
|
||||
// LayoutA : row / col
|
||||
// Gemm : 1x 2x 3x multiplier of alignment requirement. corner case that smaller than alignment requirement
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f8_t)
|
||||
{
|
||||
// Test Settings
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
|
||||
// Deduct From Test Setting
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
|
||||
using ElementAMma = cute::sparse_elem<2, ElementA>;
|
||||
using ElementEMma = cute::sparse_elem<8, uint8_t>;
|
||||
|
||||
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<64>>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::
|
||||
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
// Test Bed
|
||||
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
|
||||
EXPECT_TRUE(testbed.run_auto());
|
||||
}
|
||||
|
||||
TEST(SM90_Structured_Sparse_Gemm_Compressor_Device, f8_n)
|
||||
{
|
||||
// Test Settings
|
||||
using ElementA = cutlass::float_e5m2_t;
|
||||
using LayoutATag = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Deduct From Test Setting
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A<LayoutATag>();
|
||||
using ElementAMma = cute::sparse_elem<2, ElementA>;
|
||||
using ElementEMma = cute::sparse_elem<8, uint8_t>;
|
||||
|
||||
using SparseConfig = cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma, cute::Int<64>>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::
|
||||
StructuredSparseCompressor<cute::Shape<int, int, int, int>, ElementA, LayoutATag, SparseConfig, cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
// Test Bed
|
||||
test::transform::device::TestbedSparseGemmCompressor<Compressor> testbed;
|
||||
EXPECT_TRUE(testbed.run_auto());
|
||||
}
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
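Across the three test files, the only knobs that change are the A element type, its layout tag, and the Sm90GemmSparseConfig parameters. A useful way to read the N in `ElementEMma = cute::sparse_elem<N, uint8_t>` is as the number of logical A elements whose metadata one physical byte carries: N = 8 in the f16 and f8 tests, N = 4 in the f32/tf32 tests. A hypothetical helper making that arithmetic explicit follows; the function name and the asserted numbers are illustrative, not part of the commit.

#include <cstddef>

// Metadata bytes needed per row of A, before any alignment padding.
constexpr std::size_t metadata_bytes_per_row(std::size_t logical_k,
                                             std::size_t elems_per_metadata_byte) {
  return logical_k / elems_per_metadata_byte;
}

static_assert(metadata_bytes_per_row(64, 8) == 8,
              "f16/f8: one metadata byte covers 8 logical elements of A");
static_assert(metadata_bytes_per_row(64, 4) == 16,
              "f32/tf32: one metadata byte covers 4 logical elements of A");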

480  test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp (new file)

@@ -0,0 +1,480 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Compression utilities specific to SM90 structured sparse kernels
*/

#pragma once

#include <algorithm>  // std::fill
#include <array>      // std::array
#include <cstdio>
#include <cstring>    // std::memcpy
#include <random>     // std::mt19937
#include <stdexcept>  // std::runtime_error

#include "cute/container/bit_field.hpp"      // cute::bit_field
#include "cute/numeric/numeric_types.hpp"    // cute::sizeof_bits_v
#include "cute/tensor.hpp"                   // cute::Tensor, cute::make_tensor, cute::print_tensor
#include "cutlass/arch/arch.h"               // cutlass::arch::Sm90
#include "cutlass/cutlass.h"                 // cutlass::Status
#include "cutlass/detail/layout.hpp"         // cutlass::TagToStrideA_t
#include "cutlass/fast_math.h"               // cutlass::ceil_div, cutlass::round_up
#include "cutlass/kernel_hardware_info.h"    // cutlass::KernelHardwareInfo
#include "cutlass/util/packed_stride.hpp"    // cutlass::make_cute_packed_stride
#include "cutlass/numeric_size.h"            // cutlass::bits_to_bytes
#include "cutlass/cuda_host_adapter.hpp"     // cutlass::CudaHostAdapter

namespace cutlass
{
namespace transform
{
namespace kernel
{

using namespace cute;

namespace detail {

template<typename T>
|
||||
CUTLASS_HOST_DEVICE
|
||||
static uint8_t
|
||||
encode_in_chunk_idx_legacy(int in_chunk_idx){
|
||||
if (sizeof(T) == 4) {
|
||||
return in_chunk_idx == 0 ? 0b0100 : 0b1110;
|
||||
}
|
||||
else {
|
||||
uint8_t res = 0;
|
||||
if (in_chunk_idx == 0) {
|
||||
res = 0b00;
|
||||
}
|
||||
else if (in_chunk_idx == 1) {
|
||||
res = 0b01;
|
||||
}
|
||||
else if (in_chunk_idx == 2) {
|
||||
res = 0b10;
|
||||
}
|
||||
else {
|
||||
res = 0b11;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class SparseConfig,
|
||||
class EngineA,
|
||||
class LayoutA,
|
||||
class EngineAc,
|
||||
class LayoutAc
|
||||
>
|
||||
CUTLASS_HOST_DEVICE
|
||||
static void
|
||||
compress_two_chunks_legacy(
|
||||
Tensor<EngineA, LayoutA> tensorA,
|
||||
Tensor<EngineAc, LayoutAc> tensorAc,
|
||||
uint8_t& meta_two_chunk,
|
||||
int effective_elems) {
|
||||
|
||||
using ElementA = typename EngineAc::value_type;
|
||||
|
||||
static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
|
||||
static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
|
||||
static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
|
||||
static constexpr int ElementEBitsPerElementAMma = typename SparseConfig::ElementEBitsPerElementAMma{};
|
||||
static constexpr int LogicalSubChunk = ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
|
||||
static constexpr int PhysicalSubChunk = ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
|
||||
|
||||
/*
|
||||
Legal metadata chunk in SM90
|
||||
Index Bin HEX
|
||||
0, 1 0b0100 4
|
||||
1, 2 0b1001 9
|
||||
2, 3 0b1110 E
|
||||
0, 2 0b1000 8
|
||||
1, 3 0b1101 D
|
||||
0, 3 0b1100 C
|
||||
2, 1 0b0110 6 (Not used)
|
||||
-----------------------------------
|
||||
TF32
|
||||
0 0b0100 4
|
||||
1 0b1110 E
|
||||
*/
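  // Worked example (illustrative): with 16-bit elements a chunk holds 4 logical values and
  // its metadata is one 4-bit code from the table above. If chunk 0 keeps indices {0, 1}
  // (code 0x4) and chunk 1 keeps indices {1, 3} (code 0xD), the packing at the end of the
  // loop below produces
  //   meta_two_chunk = 0x4 | (0xD << 4) = 0xD4.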

  if (effective_elems <= 0) {
    return;
  }

  // Initialize.
  // 0 is the initial value for this function, while 0x44 is the initial value for the hardware.
  meta_two_chunk = 0;

  for (int chunk_idx = 0; chunk_idx < 2; ++chunk_idx) {
    // Stop if no elements are left for this chunk (only one chunk in this two-chunk window)
    if ( effective_elems <= chunk_idx * ElemsARawPerElementAMmaRaw * LogicalSubChunk ) {
      break;
    }
    /// Init results
    int non_zero_cnt = 0;
    int32_t nnz_chunk_idx[PhysicalSubChunk] = { 0 };
    ElementA Ac_chunk[PhysicalSubChunk][ElemsARawPerElementAMmaRaw] = { ElementA{0} };

    for (int subchunk_idx = 0; subchunk_idx < LogicalSubChunk; ++subchunk_idx) {
      bool is_zero = true;
      ElementA subchunk_elems[ElemsARawPerElementAMmaRaw] = { ElementA{0} };
      /// Check whether the subchunk is non-zero
      for (int elem_idx = 0; elem_idx < ElemsARawPerElementAMmaRaw; elem_idx++) {
        int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx;
        subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0);

        if (subchunk_elems[elem_idx] != ElementA(0)) {
          if (non_zero_cnt >= PhysicalSubChunk) {
#ifdef __CUDA_ARCH__
            asm volatile ("brkpt;\n" ::);
#else
            throw std::runtime_error("Found extra non-zero elements in a chunk!\n");
#endif
          }
          is_zero = false;
        }
      }

      /// There is a non-zero element in the subchunk
      if (!is_zero) {
        nnz_chunk_idx[non_zero_cnt] = subchunk_idx;
        memcpy(Ac_chunk[non_zero_cnt], subchunk_elems, sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
        non_zero_cnt++;
      }
    }

    /*
      Special case:
      nnz == 1 and non-tf32 and nnz_idx == 3
    */
    ElementA elementA_zeros[ElemsARawPerElementAMmaRaw] = { ElementA{0} };
    if constexpr (sizeof_bits_v<ElementA> < 32) {
      if (non_zero_cnt == 1 && nnz_chunk_idx[0] == 3) {
        memcpy(Ac_chunk[1], Ac_chunk[0], sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
        memcpy(Ac_chunk[0], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
        nnz_chunk_idx[1] = 3;
        nnz_chunk_idx[0] = 0;
      }
      else if (non_zero_cnt == 1) {
        memcpy(Ac_chunk[1], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw);
        nnz_chunk_idx[1] = 3;
      }
    }

    /// Set up metadata
    uint8_t meta_chunk = 0;
    for (int i = 0; i < PhysicalSubChunk; i++) {
      meta_chunk = static_cast<uint8_t>(meta_chunk | (encode_in_chunk_idx_legacy<ElementA>(nnz_chunk_idx[i]) << (i * ElementEBitsPerElementAMma)));
      for (int j = 0; j < ElemsARawPerElementAMmaRaw; j++) {
        tensorAc(chunk_idx * PhysicalElemsAPerChunk + i * ElemsARawPerElementAMmaRaw + j) = Ac_chunk[i][j];
      }
    }
    meta_two_chunk = uint8_t(meta_two_chunk | (meta_chunk << (chunk_idx * _4{})));
  }
}
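// Worked example of the nnz == 1 special case handled above (illustrative values, assuming
// 16-bit elements so ElemsARawPerElementAMmaRaw == 1 and 2 metadata bits per kept value):
//   chunk = [0, 0, 0, 7]  ->  a single non-zero subchunk at index 3.
//   The values are shifted so the non-zero lands in the second physical slot:
//   Ac = [0, 7] with nnz_chunk_idx = {0, 3}, which encodes to the legal code 0b1100 (0xC)
//   from the table; the illegal ordering 0b0110 is never produced.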

} // namespace detail

template<
  class ProblemShape_,
  class ElementA_,
  class LayoutATag_,
  class SparseConfig_
>
class SM90StructuredSparseCompressorLegacy {
public:
  using SparseConfig = SparseConfig_;
  using ProblemShape = ProblemShape_;

  // * EltA
  using ElementA = ElementA_;
  using ElementAUint = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
  static constexpr bool IsRuntimeDataTypeA = cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float8_t> ||
                                             cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float6_t> ||
                                             cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float4_t>;
  using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA,
                                            cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>,
                                            ElementA>;
  using ElementAMma = typename SparseConfig::ElementAMma;
  using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw;
  using ElementASparsity = typename SparseConfig::ElementASparsity;
  using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity;
  using LayoutATag = LayoutATag_;
  using LayoutA = LayoutATag;
  using StrideA = cutlass::gemm::TagToStrideA_t<LayoutATag>;

  // * EltE
  using ElementEMma = typename SparseConfig::ElementEMma;
  using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw;
  using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity;

  // * AtomE
  using TensorEAtom = typename SparseConfig::TensorEAtom;
  using TensorEAtomK = typename SparseConfig::TensorEAtomK;
  using TensorEAtomM = typename SparseConfig::TensorEAtomM;

  static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
  static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
  static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
  static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
  static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);

  // * Alignment
  static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
  static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
  static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{};
  static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{};

  // Required by `device_kernel`
  static constexpr int MaxThreadsPerBlock = 1;
  static constexpr int MinBlocksPerMultiprocessor = 1;
  using ArchTag = arch::Sm90;

  struct SharedStorage {
    /* empty, no smem needed */
  };

  static constexpr int SharedStorageSize = sizeof(SharedStorage);

  struct TransformArguments {
    ArrayElementA const* ptr_A{nullptr};
    StrideA dA{};
    ArrayElementA* ptr_ACompress{nullptr};
    ElementEMmaRaw* ptr_E{nullptr};
  };

  using TransformParams = TransformArguments;

  struct Arguments {
    ProblemShape problem_shape{};
    TransformArguments transform{};
    KernelHardwareInfo hw_info{};
  };

  struct Params {
    ProblemShape problem_shape{};
    TransformParams transform{};
    KernelHardwareInfo hw_info{};
    void* workspace = nullptr;
  };

  static Params
  to_underlying_arguments(Arguments& args, void* workspace) {
    return Params{{args.problem_shape},
                  {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E},
                  {args.hw_info},
                  workspace};
  }

  static Status
  can_implement(Arguments const& args) {
    auto [M, N, K, L] = args.problem_shape;
    if (K % LogicalElemsAPerChunk != 0) {
      CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK is not a multiple of the logical chunk size\n");
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

  static size_t
  get_workspace_size(Arguments const& args) {
    auto problem = args.problem_shape;
    const int m = cute::size<0>(problem);
    const int k = cute::size<2>(problem);
    const int l = cute::size<3>(problem);
    const int metadata_k = round_up(k, TensorEAlignmentK);
    const int metadata_m = round_up(m, TensorEAlignmentM);
    const int metadata_bytes = metadata_m * metadata_k / ElementEMmaSparsity{} * l;
    return metadata_bytes;
  }
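  // Example with illustrative numbers: for the f16 tests ElementEMmaSparsity is 8, so a
  // problem with m = 256, k = 256, l = 1 (assuming 256 already satisfies both
  // TensorEAlignmentM and TensorEAlignmentK) stages 256 * 256 / 8 = 8192 bytes of raw
  // metadata in the workspace: one byte per 8 logical elements of A.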

  static Status
  initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
                       CudaHostAdapter* cuda_adapter = nullptr) {
    cudaError_t cuda_error;

    auto workspace_size = get_workspace_size(args);
    if (workspace_size == 0) {
      return Status::kSuccess;
    } else if (workspace == nullptr) {
      return Status::kErrorInternal;
    }

    cudaPointerAttributes attri;
    cuda_error = cudaPointerGetAttributes(&attri, workspace);
    if (cuda_error != cudaSuccess) {
      return Status::kErrorInternal;
    }

    if ( attri.type == cudaMemoryTypeDevice ) {
#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER
      CUTLASS_ASSERT(cuda_adapter);
      if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast<uint8_t>(0), workspace_size, stream)) {
        return Status::kErrorInternal;
      }
#else
      cudaMemsetAsync(workspace, 0, workspace_size, stream);
      cuda_error = cudaGetLastError();
      if (cuda_error != cudaSuccess) {
        return Status::kErrorInternal;
      }
#endif
    } else {
      memset(workspace, 0, workspace_size);
    }

    return Status::kSuccess;
  }

  static dim3
  get_grid_shape(Params const& params) {
    return dim3(1, 1, 1);
  }

  static dim3
  get_block_shape() {
    return dim3(1, 1, 1);
  }

  CUTE_HOST_DEVICE
  void
  operator()(Params params, char* smem_buf = nullptr) {
    run(params, smem_buf);
  }

  CUTE_HOST_DEVICE
  static void
  run(Params params, char* smem_buf = nullptr) {
    do_compress_device_host(params);
  }

private:

  CUTE_HOST_DEVICE
  static void
  do_compress_device_host(Params params) {
    auto [m, n, k, l] = params.problem_shape;
    auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform;
    auto workspace = params.workspace;

    const int aligned_k = (k + TensorAAlignmentK - 1) / TensorAAlignmentK * TensorAAlignmentK;
    const int aligned_m = (m + TensorAAlignmentM - 1) / TensorAAlignmentM * TensorAAlignmentM;
    const int metadata_k = (k + TensorEAlignmentK - 1) / TensorEAlignmentK * TensorEAlignmentK;
    const int metadata_m = (m + TensorEAlignmentM - 1) / TensorEAlignmentM * TensorEAlignmentM;
    const int k_compressed = aligned_k / ElementASparsity{};
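    // Illustrative numbers: with k = 100 and TensorAAlignmentK = 64 (assumed), aligned_k
    // rounds up to 128; ElementASparsity is 2 for 2:4 structured sparsity, so the
    // compressed tensor stores k_compressed = 64 physical elements per row.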

    // Convert to CuTe tensors. We deliberately avoid sparse_ptr here, since it would
    // complicate everything that follows.
    cute::Tensor tensorA = make_tensor(recast_ptr<ElementAUint>(ptr_A), make_layout(make_shape(m, k, l), dA));

    cute::Tensor tensorAc = make_tensor(recast_ptr<ElementAUint>(ptr_ACompress),
                                        make_shape(aligned_m, k_compressed, l),
                                        make_cute_packed_stride(StrideA{}, cute::make_shape(aligned_m, k_compressed, l)));

    cute::Tensor tensorE_raw_compress_logical = make_tensor(recast_ptr<sparse_elem<ElementEMmaSparsity{},ElementEMmaRaw>>(workspace),
                                                            make_shape(metadata_m, make_shape(TensorEAtomK{}, metadata_k / TensorEAtomK{}), l),
                                                            make_stride(TensorEAtomK{}, make_stride(_1{}, metadata_m*TensorEAtomK{}), metadata_m*metadata_k));

    cute::Tensor tensorE_raw_compress = recast<uint8_t>(tensorE_raw_compress_logical);

    // The following vars are all logical.
    int atom_m = size<0>(TensorEAtom{});
    int atom_k = size<1>(TensorEAtom{});
    int tiled_m = metadata_m / atom_m;
    int tiled_ke = metadata_k / atom_k;
    // Column-major when viewing atoms
    int stride_tile_m = cosize(TensorEAtom{});
    int stride_tile_ke = atom_k * metadata_m;

    // Logical metadata tensor
    cute::Tensor tensorE_logical = make_tensor(recast_ptr<sparse_elem<ElementEMmaSparsity{},ElementEMmaRaw>>(ptr_E),
                                               make_layout(make_shape(append(shape<0>(TensorEAtom{}), tiled_m),
                                                                      append(shape<1>(TensorEAtom{}), tiled_ke),
                                                                      shape<2>(tensorE_raw_compress_logical)),
                                                           make_stride(append(stride<0>(TensorEAtom{}), stride_tile_m),
                                                                       append(stride<1>(TensorEAtom{}), stride_tile_ke),
                                                                       stride<2>(tensorE_raw_compress_logical))));
    // Physical metadata tensor
    cute::Tensor tensorE = recast<uint8_t>(tensorE_logical);

    // void do_init()
    cute::clear(tensorAc);
    cute::clear(tensorE_raw_compress);

    // void do_raw_compress()
    using TileStepA = Int<LogicalElemsAPerChunk * 2>;
    using TileStepAc = Int<TileStepA{} / 2>;

    cute::Tensor tensorATiled = logical_divide(tensorA, make_shape(_, TileStepA{}, _));
    cute::Tensor tensorAcTiled = logical_divide(tensorAc, make_shape(_, TileStepAc{}, _));
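    // The logical_divide calls above split the K mode into (step, tile index). For the f16
    // configuration (LogicalElemsAPerChunk = 4, assumed), TileStepA = 8 logical elements map
    // onto TileStepAc = 4 compressed elements, and tensorATiled(m, make_coord(_, t), l)
    // below selects the t-th two-chunk window that compress_two_chunks_legacy consumes.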

    for (int batch_idx = 0; batch_idx < l; batch_idx++) {
      for (int m_idx = 0; m_idx < m; m_idx++) {
        for (int tiler_k_idx = 0; tiler_k_idx < size<1,1>(tensorATiled); tiler_k_idx++) {
          int effective_elems = cute::min(TileStepA{}, k - (tiler_k_idx * TileStepA{}));
          detail::compress_two_chunks_legacy<SparseConfig>(tensorATiled(m_idx, make_coord(_, tiler_k_idx), batch_idx),
                                                           tensorAcTiled(m_idx, make_coord(_, tiler_k_idx), batch_idx),
                                                           tensorE_raw_compress(m_idx, tiler_k_idx, batch_idx),
                                                           effective_elems);
        }
      }
    }

    // void do_reorder()
    // Fast path when we don't permute.
    if constexpr (sizeof_bits_v<ElementAUint> <= 8) {
      memcpy(tensorE.data(), tensorE_raw_compress.data(), tensorE.size());
    }
    else {
      cute::copy(tensorE_raw_compress, tensorE);
    }

#if 0
    print("--> TensorA\n");
    auto tensorA_eltA = cute::recast<ElementA>(tensorA);
    cute::print_tensor(tensorA_eltA); printf("\n\n");

    print("--> REF TensorAC\n");
    auto tensorAc_eltA = cute::recast<ElementA>(tensorAc);
    cute::print_tensor(tensorAc_eltA); printf("\n\n");

    print("--> REF TensorE\n");
    cute::print_tensor(tensorE); printf("\n\n");
#endif

  }
};

} // namespace kernel
} // namespace transform
} // namespace cutlass

876  test/unit/transform/device/testbed_sparse_gemm_compressor.hpp (new file)

@@ -0,0 +1,876 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*
 * @brief Test for the structured sparse GEMM compressor device kernel
 */

#pragma once

#include <cuda_runtime_api.h>  // cudaGetLastError

#include <cstdint>   // uint64_t
#include <cstdio>    // printf
#include <cstdlib>   // malloc
#include <iostream>  // std::cout
#include <vector>
#include <array>

#include "cute/layout.hpp"                                            // cute::make_shape
#include "cute/util/type_traits.hpp"                                  // cute::is_same_v
#include "cutlass/coord.h"                                            // cutlass::make_Coord
#include "cutlass/cutlass.h"                                          // cutlass::Status
#include "cutlass/kernel_hardware_info.hpp"                           // cutlass::KernelHardwareInfo
#include "cutlass/layout/matrix.h"                                    // cutlass::layout::Affine2Layout_Factory
#include "cutlass/numeric_types.h"                                    // cutlass::sizeof_bits, cutlass::float_
#include "cutlass/tensor_view.h"                                      // cutlass::TensorView
#include "cutlass/transform/device/transform_universal_adapter.hpp"   // cutlass::transform::device::TransformUniversalAdapter
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"        // cutlass::transform::kernel::StructuredSparseCompressorUtility
#include "cutlass/util/device_memory.h"                               // cutlass::device_memory::allocation
#include "cutlass/util/distribution.h"                                // cutlass::Distribution
#include "cutlass/util/host_tensor.h"                                 // cutlass::HostTensor
#include "cutlass/util/packed_stride.hpp"                             // cutlass::make_cute_packed_stride
#include "cutlass/util/reference/host/tensor_compare.h"               // cutlass::reference::host::TensorEquals
#include "cutlass/util/reference/host/tensor_fill.h"                  // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill

#include "sm90_sparse_gemm_compressor_legacy.hpp"  // Legacy host compressor
#include "../../common/cutlass_unit_test.h"        // CUTLASS UT, EXPECT_TRUE

// Evaluate the expression exactly once; passing cudaGetLastError() into a macro that
// expands its argument twice would clear the error before it can be printed.
#define CUDA_CHECK_FALSE(expr)                                                              \
  do {                                                                                      \
    cudaError_t cuda_error_ = (expr);                                                       \
    if (cuda_error_ != cudaSuccess) {                                                       \
      printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error_), __func__, __LINE__); \
      return false;                                                                         \
    }                                                                                       \
  } while (0)

#define CUDA_CHECK(expr)                                                                    \
  do {                                                                                      \
    cudaError_t cuda_error_ = (expr);                                                       \
    if (cuda_error_ != cudaSuccess) {                                                       \
      printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error_), __func__, __LINE__); \
      return;                                                                               \
    }                                                                                       \
  } while (0)

///////////////////////////////////////////////////////////////////////////////////////////////////
// * Test Bed
///////////////////////////////////////////////////////////////////////////////////////////////////

namespace test
{
namespace transform
{
namespace device
{

// Helper functions
template <typename Element, typename Layout>
bool
initialize_tensor(cutlass::TensorView<Element, Layout> view, cutlass::Distribution::Kind dist_kind, uint64_t seed)
{
  if (dist_kind == cutlass::Distribution::Uniform) {
    double scope_max, scope_min;
    int bits_input = cutlass::sizeof_bits<Element>::value;

    if (bits_input == 1) {
      scope_max = 2;
      scope_min = 0;
    }
    else if (bits_input <= 8) {
      scope_max = 1;
      scope_min = -1;
    } else {
      scope_max = 4;
      scope_min = -4;
    }
    cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0);
  }

  else if (dist_kind == cutlass::Distribution::Identity) {
    cutlass::reference::host::TensorFillIdentity(view);
  }

  else if (dist_kind == cutlass::Distribution::Gaussian) {
    cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
  }

  else if (dist_kind == cutlass::Distribution::Sequential) {
    cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
  }

  else if (dist_kind == cutlass::Distribution::AllOnes) {
    cutlass::reference::host::TensorFill(view, Element(1));
  }

  else if (dist_kind == cutlass::Distribution::AllZeros) {
    cutlass::reference::host::TensorFill(view, Element(0));
  }

  else {
    EXPECT_TRUE(false) << "Not implemented";
    return false;
  }

  return true;
}
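// Typical call, as used by TestbedSparseGemmCompressor::initialize() below; 16-bit
// inputs draw uniformly from [-4, 4] per the bounds above:
//   initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 1);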

// Testbed
template <typename Compressor_>
struct TestbedSparseGemmCompressor {
public:
  using Compressor = Compressor_;
  using CompressorKernel = typename Compressor::TransformKernel;

  using ElementA = typename CompressorKernel::ElementA;
  using LayoutATag = typename CompressorKernel::LayoutATag;
  using StrideA = typename CompressorKernel::StrideA;
  using ArrayElementA = ElementA;

  using ElementE = typename CompressorKernel::ElementEMmaRaw;
  using LayoutETag = cutlass::layout::RowMajor;  // Majorness does not matter here; it is only used to allocate the tensor

  using SparseConfig = typename CompressorKernel::SparseConfig;
  using ProblemShapeType = typename CompressorKernel::ProblemShape;

  using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
    ProblemShapeType,
    ElementA,
    LayoutATag,
    SparseConfig>;

  using CompressorKernelHost = cutlass::transform::kernel::SM90StructuredSparseCompressorLegacy<
    ProblemShapeType,
    ElementA,
    LayoutATag,
    SparseConfig>;

  using CompressorHost = cutlass::transform::device::TransformUniversalAdapter<CompressorKernelHost>;

  static constexpr auto LogicalElemsAPerChunk = CompressorKernel::LogicalElemsAPerChunk;
  static constexpr auto PhysicalElemsAPerChunk = CompressorKernel::PhysicalElemsAPerChunk;

  struct Data {
    // Data storage
    cutlass::HostTensor<ArrayElementA, LayoutATag> tensor_A;
    cutlass::HostTensor<ArrayElementA, LayoutATag> tensor_A_Comp;
    cutlass::HostTensor<ElementE, LayoutETag> tensor_E;
    cutlass::HostTensor<ArrayElementA, LayoutATag> tensor_A_Comp_ref;
    cutlass::HostTensor<ElementE, LayoutETag> tensor_E_ref;
  };

  struct CudaRAII {
    cudaStream_t stream;
    cudaEvent_t start;
    cudaEvent_t stop;

    CudaRAII() {
      CUDA_CHECK(cudaStreamCreate( &stream ));
      CUDA_CHECK(cudaEventCreate( &start ));
      CUDA_CHECK(cudaEventCreate( &stop ));
    }

    CudaRAII(const CudaRAII&) = delete;
    CudaRAII& operator=(const CudaRAII&) = delete;
    CudaRAII(CudaRAII&&) = delete;
    CudaRAII& operator=(CudaRAII&&) = delete;

    ~CudaRAII() {
      CUDA_CHECK(cudaStreamDestroy( stream ));
      CUDA_CHECK(cudaEventDestroy( start ));
      CUDA_CHECK(cudaEventDestroy( stop ));
    }
  };

public:
  TestbedSparseGemmCompressor(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_A_Comp_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 7)
    : init_A(init_A_)
    , init_E(init_E_)
    , init_A_Comp(init_A_Comp_)
    , seed(seed_)
  {
  }

  bool valid_test(ProblemShapeType problem_shape_MNKL)
  {
    const int GemmK = cute::size<2>(problem_shape_MNKL);

    if ( GemmK % LogicalElemsAPerChunk != 0 ) {
      printf("GemmK must be a multiple of LogicalElemsAPerChunk\n");
      return false;
    }

    return true;
  }

  bool initialize(ProblemShapeType problem_shape_MNKL, Data& datas)
  {
    CUDA_CHECK_FALSE(cudaGetLastError());

    // In units of ElementARaw
    const int GemmM = cute::size<0>(problem_shape_MNKL);
    const int GemmN = cute::size<1>(problem_shape_MNKL);
    const int GemmK = cute::size<2>(problem_shape_MNKL);
    const int GemmL = cute::size<3>(problem_shape_MNKL);

    // Compressor utility to get the allocated data sizes
    auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL));
    CompressorUtility compressor_utility(problem_shape_MNKL, stride_a);

    // TensorA
    // In units of ElementARaw, after alignment requirements
    // M-dim: no alignment requirement
    // K-dim: multiple of the chunk size

    // TensorA compressed
    // In units of ElementARaw, after alignment requirements
    // M-dim: TMA alignment
    // K-dim: TMA alignment
    const int GemmMAlignedAC = compressor_utility.get_tensorA_m_physical();
    const int GemmKAlignedAC = compressor_utility.get_tensorA_k_physical();

    // TensorE
    // In units of ElementE (uint8_t), after alignment requirements
    // M-dim: TensorEAtom_M alignment
    // K-dim: TensorEAtom_K alignment
    const int GemmMAlignedE = compressor_utility.get_metadata_m_physical();
    const int GemmKAlignedE = compressor_utility.get_metadata_k_physical();

    auto a_coord = cutlass::make_Coord(GemmM * GemmL, GemmK);
    auto e_coord = cutlass::make_Coord(GemmMAlignedE * GemmL, GemmKAlignedE);
    auto a_comp_coord = cutlass::make_Coord(GemmMAlignedAC * GemmL, GemmKAlignedAC);

    typename LayoutATag::Stride stride_factor_A;
    typename LayoutETag::Stride stride_factor_E;

    datas.tensor_A.resize(a_coord,
                          cutlass::layout::Affine2Layout_Factory<LayoutATag>::layout_factory(a_coord, stride_factor_A));
    datas.tensor_A_Comp.resize(a_comp_coord,
                               cutlass::layout::Affine2Layout_Factory<LayoutATag>::layout_factory(a_comp_coord, stride_factor_A));
    datas.tensor_A_Comp_ref.resize(a_comp_coord,
                                   cutlass::layout::Affine2Layout_Factory<LayoutATag>::layout_factory(a_comp_coord, stride_factor_A),
                                   false);
    datas.tensor_E.resize(e_coord,
                          cutlass::layout::Affine2Layout_Factory<LayoutETag>::layout_factory(e_coord, stride_factor_E));
    datas.tensor_E_ref.resize(e_coord,
                              cutlass::layout::Affine2Layout_Factory<LayoutETag>::layout_factory(e_coord, stride_factor_E),
                              false);

    EXPECT_TRUE(initialize_tensor(datas.tensor_A.host_view(), init_A, seed + 1));
    EXPECT_TRUE(initialize_tensor(datas.tensor_E.host_view(), init_E, seed + 2));
    EXPECT_TRUE(initialize_tensor(datas.tensor_E_ref.host_view(), init_E, seed + 3));
    EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp.host_view(), init_A_Comp, seed + 4));
    EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp_ref.host_view(), init_A_Comp, seed + 5));

    compressor_utility.structure_sparse_zero_mask_fill(datas.tensor_A.host_data(), seed + 6);

    // Check for device errors
    CUDA_CHECK_FALSE(cudaGetLastError());

    datas.tensor_A.sync_device();
    datas.tensor_A_Comp.sync_device();
    datas.tensor_E.sync_device();

    // Check for device errors
    CUDA_CHECK_FALSE(cudaGetLastError());

    return true;
  }

  bool run_device(ProblemShapeType problem_shape_MNKL, Data& datas, float* time = nullptr)
  {
    CudaRAII cuda_raii;

    const int GemmM = cute::size<0>(problem_shape_MNKL);
    const int GemmN = cute::size<1>(problem_shape_MNKL);
    const int GemmK = cute::size<2>(problem_shape_MNKL);
    const int GemmL = cute::size<3>(problem_shape_MNKL);

    StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL));

    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    typename Compressor::Arguments arguments{
      {GemmM, GemmN, GemmK, GemmL},
      {datas.tensor_A.device_data(),
       stride_a,
       datas.tensor_A_Comp.device_data(),
       datas.tensor_E.device_data()},
      {hw_info}
    };

    Compressor compressor_op;
    size_t workspace_size = Compressor::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status {cutlass::Status::kSuccess};

    status = compressor_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
      CUDA_CHECK_FALSE(cudaGetLastError());
      return false;
    }

    status = compressor_op.initialize(arguments, workspace.get(), cuda_raii.stream);
    if (status != cutlass::Status::kSuccess) {
      CUDA_CHECK_FALSE(cudaGetLastError());
      return false;
    }

    CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream));
    CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.start, cuda_raii.stream));

    status = compressor_op.run(cuda_raii.stream);
    if (status != cutlass::Status::kSuccess) {
      CUDA_CHECK_FALSE(cudaGetLastError());
      return false;
    }

    CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.stop, cuda_raii.stream));
    CUDA_CHECK_FALSE(cudaEventSynchronize(cuda_raii.stop));
    CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream));
    if ( time != nullptr ) {
      CUDA_CHECK_FALSE(cudaEventElapsedTime(time, cuda_raii.start, cuda_raii.stop));
    }

    datas.tensor_A_Comp.sync_host();
    datas.tensor_E.sync_host();

#if 0
    {
      printf("\n--> DEVICE OUTPUT\n");
      printf("datas.tensor_A\n");
      std::cout << datas.tensor_A.host_view() << std::endl << std::endl;
      printf("datas.tensor_A_Comp\n");
      std::cout << datas.tensor_A_Comp.host_view() << std::endl << std::endl;
      printf("datas.tensor_E\n");
      std::cout << datas.tensor_E.host_view() << std::endl << std::endl;
    }
#endif

    return true;
  }

bool run_host_ref(ProblemShapeType problem_shape_MNKL, Data& datas)
|
||||
{
|
||||
const int GemmM = cute::size<0>(problem_shape_MNKL);
|
||||
const int GemmN = cute::size<1>(problem_shape_MNKL);
|
||||
const int GemmK = cute::size<2>(problem_shape_MNKL);
|
||||
const int GemmL = cute::size<3>(problem_shape_MNKL);
|
||||
|
||||
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL));
|
||||
|
||||
typename CompressorKernelHost::Arguments arguments{
|
||||
{GemmM, GemmN, GemmK, GemmL},
|
||||
{datas.tensor_A.host_data(),
|
||||
stride_a,
|
||||
datas.tensor_A_Comp_ref.host_data(),
|
||||
datas.tensor_E_ref.host_data()},
|
||||
{}};
|
||||
|
||||
const auto can_imp = CompressorKernelHost::can_implement(arguments);
|
||||
if (can_imp != cutlass::Status::kSuccess) {
|
||||
printf("can_implement() check failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Relies on std::vector for RAII
|
||||
auto workspace_size =
|
||||
static_cast<std::vector<uint8_t>::size_type>(CompressorKernelHost::get_workspace_size(arguments));
|
||||
std::vector<uint8_t> workspace_vector(workspace_size);
|
||||
auto workspace = static_cast<void*>(workspace_vector.data());
|
||||
|
||||
cutlass::Status status = CompressorKernelHost::initialize_workspace(arguments, workspace);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
printf("initialize_workspace() failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto params = CompressorKernelHost::to_underlying_arguments(arguments, workspace);
|
||||
CompressorKernelHost::run(params);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
  bool compare_reference(Data& datas)
  {
    bool check_tensor_a_compressed =
        cutlass::reference::host::TensorEquals(datas.tensor_A_Comp_ref.host_view(), datas.tensor_A_Comp.host_view());
    if (!check_tensor_a_compressed) {
      printf("A-Compressed Mismatch\n");
    }

    bool check_tensor_e =
        cutlass::reference::host::TensorEquals(datas.tensor_E_ref.host_view(), datas.tensor_E.host_view());
    if (!check_tensor_e) {
      printf("E Mismatch\n");
    }

    return check_tensor_a_compressed && check_tensor_e;
  }

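  // Problem-shape sweep. run_auto(true) keeps only the alignment-friendly shapes
  // for a quick smoke test; run_auto(false) additionally covers the large shapes
  // and the chunk-granular corner cases assembled below.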
  bool run_auto_small()
  {
    return run_auto(true);
  }

  bool run_auto(bool run_small = false)
  {
    constexpr auto TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
    constexpr auto TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
    constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};

    constexpr int GemmN = 1;

    using ProblemType = std::array<int, 4>;

    std::vector<ProblemType> problems;

    const std::vector<ProblemType> problems_multiplier_of_tensor_e_atom = {
        // * Regular Cases (multiples of TensorEAlignment)
        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1},
        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1},
        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 1},

        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1},
        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1},
        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 1},

        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1},
        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1},
        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 1},

        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2},
        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2},
        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 2},

        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2},
        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2},
        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 2},

        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2},
        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2},
        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 2},

        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3},
        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3},
        {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 3},

        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3},
        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3},
        {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 3},

        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3},
        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3},
        {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 3},
    };

    const std::vector<ProblemType> problems_multiplier_of_tensor_e_atom_large = {
        // * Large Case (multiples of TensorEAlignment)
        {TensorEAlignmentM * 10, GemmN, TensorEAlignmentK * 13, 1},
        // {TensorEAlignmentM * 11, GemmN, TensorEAlignmentK * 14, 2},
        // {TensorEAlignmentM * 12, GemmN, TensorEAlignmentK * 15, 3},
    };

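    // Corner cases: M extents offset by +4 rows (so M is not a multiple of
    // TensorEAlignmentM) combined with K extents built from an even number of
    // logical chunks, exercising alignment tails in both M and K.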
    const std::vector<ProblemType> problems_multiplier_of_twochunk = {
        // * Corner Cases
        {4, GemmN, LogicalElemsAPerChunk * 2, 1},
        {4, GemmN, LogicalElemsAPerChunk * 4, 1},
        {4, GemmN, LogicalElemsAPerChunk * 6, 1},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},

        {4, GemmN, LogicalElemsAPerChunk * 2, 2},
        {4, GemmN, LogicalElemsAPerChunk * 4, 2},
        {4, GemmN, LogicalElemsAPerChunk * 6, 2},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},

        {4, GemmN, LogicalElemsAPerChunk * 2, 3},
        {4, GemmN, LogicalElemsAPerChunk * 4, 3},
        {4, GemmN, LogicalElemsAPerChunk * 6, 3},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},

        {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 1},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 1},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 1},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},

        {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 2},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 2},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 2},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},

        {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 3},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 3},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 3},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},

        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 1},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 1},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},

        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 2},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 2},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},

        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 3},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 3},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},

        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1},

        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2},

        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3},
    };

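    // Same corner-case sweep with odd chunk multiples (1, 3, 5), covering K
    // extents that are a multiple of one chunk but not of two.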
    const std::vector<ProblemType> problems_multiplier_of_onechunk = {
        {4, GemmN, LogicalElemsAPerChunk * 1, 1},
        {4, GemmN, LogicalElemsAPerChunk * 3, 1},
        {4, GemmN, LogicalElemsAPerChunk * 5, 1},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},

        {4, GemmN, LogicalElemsAPerChunk * 1, 2},
        {4, GemmN, LogicalElemsAPerChunk * 3, 2},
        {4, GemmN, LogicalElemsAPerChunk * 5, 2},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},

        {4, GemmN, LogicalElemsAPerChunk * 1, 3},
        {4, GemmN, LogicalElemsAPerChunk * 3, 3},
        {4, GemmN, LogicalElemsAPerChunk * 5, 3},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
        {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
        {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},

        {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 1},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 1},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 1},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},

        {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 2},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 2},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 2},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},

        {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 3},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 3},
        {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 3},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
        {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
        {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},

        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 1},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 1},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},

        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 2},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 2},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},

        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 3},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 3},
        {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
        {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},

        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1},

        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2},

        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3},
        {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3},
    };

    // Small run: only the shapes that are multiples of the TensorE alignment
    if (run_small) {
      problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end());
    }
    // Full run: all shape lists, including the large and chunk-granular corner cases
    else {
      problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom_large.begin(), problems_multiplier_of_tensor_e_atom_large.end());
      problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end());
      problems.insert(problems.end(), problems_multiplier_of_twochunk.begin(), problems_multiplier_of_twochunk.end());
      problems.insert(problems.end(), problems_multiplier_of_onechunk.begin(), problems_multiplier_of_onechunk.end());
    }

    for (const auto& problem_shape_MNKL : problems) {
      const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL;
      bool passed = run({GemmM, GemmN, GemmK, GemmL});
      printf("run() (%.4d,%.4d,%.4d,%.4d) %s\n", GemmM, GemmN, GemmK, GemmL, passed ? "PASS" : "FAIL");
      CUTLASS_TRACE_HOST("run() " << GemmM << " " << GemmN << " " << GemmK << " " << GemmL << (passed ? " PASS" : " FAIL"));
      if (not passed) {
        return false;
      }
    }

    return true;
  }

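  // End-to-end check for a single problem shape: validate the shape, initialize
  // the data, run the host reference and the device compressor, then compare
  // their outputs.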
  bool run(ProblemShapeType problem_shape_MNKL)
  {
    // Check if valid test
    if (not valid_test(problem_shape_MNKL)) {
      CUTLASS_TRACE_HOST("valid_test() fail\n");
      return false;
    }

    // Data Storage
    Data datas;

    // Initialize Data
    if (not initialize(problem_shape_MNKL, datas)) {
      CUTLASS_TRACE_HOST("initialize() fail\n");
      return false;
    }

    // Run Compressor (Host Ref)
    if (not run_host_ref(problem_shape_MNKL, datas)) {
      CUTLASS_TRACE_HOST("run_host_ref() fail\n");
      return false;
    }

    // Run Compressor (Device)
    if (not run_device(problem_shape_MNKL, datas)) {
      CUTLASS_TRACE_HOST("run_device() fail\n");
      return false;
    }

    // Verify
    if (not compare_reference(datas)) {
      CUTLASS_TRACE_HOST("compare_reference() DEVICE <-> LEGACY HOST fail\n");
      printf("compare_reference() DEVICE <-> LEGACY HOST fail\n");
      return false;
    }
    // else {
    //   printf("DEVICE <-> HOST PASS\n");
    // }

    return true;
  }

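  // Wall-clock benchmark of the device compressor. Each iteration gets its own
  // freshly initialized Data instance to mimic a cold cache, and the first
  // num_warmup timings are discarded from the reported mean.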
  bool benchmark(ProblemShapeType problem_shape_MNKL) {
    const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL;
    printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) START\n", GemmM, GemmN, GemmK, GemmL);

    // Check if valid test
    if (not valid_test(problem_shape_MNKL)) {
      CUTLASS_TRACE_HOST("valid_test() fail\n");
      return false;
    }

    // 5 warm-up iterations and 10 timing iterations
    constexpr int num_warmup = 5;
    constexpr int num_iter = 10;

    // Duplicate data to mimic a cold cache
    Data data[num_warmup + num_iter];
    double total_time_milliseconds{0.0};

    for (int i = 0; i < num_warmup + num_iter; ++i) {
      printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) ITER %d\n", GemmM, GemmN, GemmK, GemmL, i);

      auto& datum_i = data[i];

      // Initialize Data
      if (not initialize(problem_shape_MNKL, datum_i)) {
        CUTLASS_TRACE_HOST("initialize() fail\n");
        return false;
      }

      // Run Compressor (Device)
      double time_i_milliseconds{0.0};
      if (not run_device(problem_shape_MNKL, datum_i, &time_i_milliseconds)) {
        CUTLASS_TRACE_HOST("run_device() fail\n");
        return false;
      }

      // Accumulate timings only after the warm-up iterations
      if (i >= num_warmup) {
        total_time_milliseconds += time_i_milliseconds;
      }
    }

    const double mean_time_milliseconds = total_time_milliseconds / num_iter;
    printf("Mean time (ms): %.5f\n", mean_time_milliseconds);

    return true;
  }

public:
  // Data Init Setting
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_A_Comp;
  cutlass::Distribution::Kind init_E;
  uint64_t seed;
};

} // namespace device
} // namespace transform
} // namespace test