Add support for sparse GEMM with row broadcasted bias vector (#951)

This commit is contained in:
Aleksandar Samardžić
2023-05-24 16:25:05 +02:00
committed by GitHub
parent b4ab501767
commit d3e72719b4
7 changed files with 1847 additions and 7 deletions

View File

@@ -37,6 +37,7 @@
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/device/gemm_sparse_row_broadcast.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
@@ -267,6 +268,24 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128)
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
// Sparse GEMM with C supplied as an M x 1 column that is broadcast across all
// N output columns (row-broadcast bias). Uses f16 A/B (column-major), f16
// output (row-major), f32 accumulation on SM80 tensor cores.
TEST(SM80_Device_Sparse_Gemm_Row_Broadcast_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128) {
  using OutputT = cutlass::half_t;
  using AccumT = float;

  // 128-bit aligned epilogue accesses: 128 / bits(OutputT) elements per access.
  using Epilogue = cutlass::epilogue::thread::LinearCombination<
      OutputT, 128 / cutlass::sizeof_bits<OutputT>::value, AccumT, AccumT>;

  using Gemm = cutlass::gemm::device::SparseGemmRowBroadcast<
      cutlass::half_t, cutlass::layout::ColumnMajor,   // A
      cutlass::half_t, cutlass::layout::ColumnMajor,   // B
      OutputT, cutlass::layout::RowMajor,              // C/D
      AccumT,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 64, 128>,           // threadblock tile
      cutlass::gemm::GemmShape<32, 32, 128>,           // warp tile
      cutlass::gemm::GemmShape<16, 8, 32>,             // instruction shape
      Epilogue,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;                                              // pipeline stages

  // `true` selects the row-broadcast tensor_C path in the testbed.
  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>(true));
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)

View File

@@ -164,14 +164,19 @@ struct SparseTestbed {
}
/// Initializes data structures
void initialize(cutlass::gemm::GemmCoord problem_size) {
void initialize(cutlass::gemm::GemmCoord problem_size, bool tensor_C_row_broadcast = false) {
//
// Allocate the GEMM workspace
//
tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse));
tensor_A_uncompressed.resize(problem_size.mk());
tensor_B.resize(problem_size.kn());
tensor_C.resize(problem_size.mn());
if (tensor_C_row_broadcast) {
tensor_C.resize({problem_size.m(), 1});
}
else {
tensor_C.resize(problem_size.mn());
}
tensor_D.resize(problem_size.mn());
reference_D.resize(problem_size.mn(), false);
tensor_E.resize(cutlass::make_Coord(
@@ -206,7 +211,14 @@ struct SparseTestbed {
tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1);
tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1);
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
if (tensor_C_row_broadcast) {
for (int i = 0; i < problem_size.m(); ++i)
for (int j = 0; j < problem_size.n(); ++j)
reference_D.host_view().at({i, j}) = tensor_C.host_view().at({i, 0});
}
else {
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
}
tensor_A.sync_device();
tensor_B.sync_device();
@@ -338,7 +350,8 @@ struct SparseTestbed {
cutlass::gemm::GemmCoord problem_size,
int split_k_slices = 1,
ElementCompute alpha = ElementCompute(1),
ElementCompute beta = ElementCompute(0)) {
ElementCompute beta = ElementCompute(0),
bool tensor_C_row_broadcast = false) {
// Waive test if insufficient CUDA device
if (!sufficient()) {
@@ -348,7 +361,7 @@
return true;
}
this->initialize(problem_size);
this->initialize(problem_size, tensor_C_row_broadcast);
//
// Initialize the GEMM operator
@@ -403,7 +416,7 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
bool TestAllSparseGemm() {
bool TestAllSparseGemm(bool tensor_C_row_broadcast = false) {
bool passed = true;
int const kMinimumOperandElementSize =
@@ -463,7 +476,8 @@ bool TestAllSparseGemm() {
problem_size,
split_k,
cutlass::from_real<ElementCompute>(alpha),
cutlass::from_real<ElementCompute>(beta)
cutlass::from_real<ElementCompute>(beta),
tensor_C_row_broadcast
);
if (!passed) {