Adding more Threadblock Tiles for Mixed-input TensorOp (BF16 * S8) in cutlass_library (#1132)

* Adding more tiles in the cutlass_library for mixed-input support.

* fix rebase issue

* more tiles to upcast a
This commit is contained in:
Manish Gupta
2023-10-13 08:33:15 -07:00
committed by GitHub
parent fa8dfe631f
commit 757275f279
14 changed files with 830 additions and 20 deletions

View File

@ -350,10 +350,12 @@ cutlass_test_unit_add_executable(
# Upcast on Operand A
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu
# Upcast on Operand B
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu
)
cutlass_test_unit_add_executable(

View File

@ -0,0 +1,278 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_universal.h"
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) {
using ElementA = cutlass::bfloat16_t;
using ElementB = int8_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 128x128x32_64x64x32) {
using ElementA = cutlass::bfloat16_t;
using ElementB = int8_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 64x128x32_32x64x32) {
using ElementA = cutlass::bfloat16_t;
using ElementB = int8_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 128x64x32_64x32x32) {
using ElementA = cutlass::bfloat16_t;
using ElementB = int8_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
8, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 64x64x32_32x32x32) {
using ElementA = cutlass::bfloat16_t;
using ElementB = int8_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
8, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 16x128x32_16x64x32) {
using ElementA = cutlass::bfloat16_t;
using ElementB = int8_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<16, 128, 32>,
cutlass::gemm::GemmShape<16, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
8, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////

View File

@ -56,7 +56,7 @@
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
TEST(SM80_Device_GemmUniversal_f16t_s8n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
using ElementA = cutlass::half_t;
using ElementB = int8_t;

View File

@ -0,0 +1,384 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_universal.h"
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x128x32_64x64x32) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 64x128x32_32x64x32) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 64x64x32_32x32x32) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
8, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x64x32_64x32x32) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
8, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x64x32_64x64x32) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
8, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x32x32_64x32x32) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 32, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
8, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x16x32_64x16x32) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 16, 32>,
cutlass::gemm::GemmShape<64, 16, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 4,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x16x64_64x16x64) {
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 16, 64>,
cutlass::gemm::GemmShape<64, 16, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 4,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////

View File

@ -56,7 +56,7 @@
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
TEST(SM80_Device_GemmUniversal_s8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
using ElementA = int8_t;
using ElementB = cutlass::half_t;

View File

@ -75,7 +75,6 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_
.run();
}
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@ -140,7 +139,6 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
@ -227,6 +225,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= B16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
@ -251,7 +250,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= B16 * U8 + F32 (Upcast on Operand B)
/// F32 <= U8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
@ -297,7 +296,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_1
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= B16 * I8 + F32 (Upcast on Operand B)
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;