Add support for sparse GEMM with row broadcasted bias vector (#951)
This commit is contained in:
committed by
GitHub
parent
b4ab501767
commit
d3e72719b4
@ -37,6 +37,7 @@
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_sparse.h"
|
||||
#include "cutlass/gemm/device/gemm_sparse_row_broadcast.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
@ -267,6 +268,24 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128)
|
||||
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
|
||||
}
|
||||
|
||||
TEST(SM80_Device_Sparse_Gemm_Row_Broadcast_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128) {
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Gemm = cutlass::gemm::device::SparseGemmRowBroadcast<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
|
||||
|
||||
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>(true));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)
|
||||
|
||||
@ -164,14 +164,19 @@ struct SparseTestbed {
|
||||
}
|
||||
|
||||
/// Initializes data structures
|
||||
void initialize(cutlass::gemm::GemmCoord problem_size) {
|
||||
void initialize(cutlass::gemm::GemmCoord problem_size, bool tensor_C_row_broadcast = false) {
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse));
|
||||
tensor_A_uncompressed.resize(problem_size.mk());
|
||||
tensor_B.resize(problem_size.kn());
|
||||
tensor_C.resize(problem_size.mn());
|
||||
if (tensor_C_row_broadcast) {
|
||||
tensor_C.resize({problem_size.m(), 1});
|
||||
}
|
||||
else {
|
||||
tensor_C.resize(problem_size.mn());
|
||||
}
|
||||
tensor_D.resize(problem_size.mn());
|
||||
reference_D.resize(problem_size.mn(), false);
|
||||
tensor_E.resize(cutlass::make_Coord(
|
||||
@ -206,7 +211,14 @@ struct SparseTestbed {
|
||||
tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1);
|
||||
tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1);
|
||||
|
||||
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
|
||||
if (tensor_C_row_broadcast) {
|
||||
for (int i = 0; i < problem_size.m(); ++i)
|
||||
for (int j = 0; j < problem_size.n(); ++j)
|
||||
reference_D.host_view().at({i, j}) = tensor_C.host_view().at({i, 0});
|
||||
}
|
||||
else {
|
||||
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
|
||||
}
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
@ -338,7 +350,8 @@ struct SparseTestbed {
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
int split_k_slices = 1,
|
||||
ElementCompute alpha = ElementCompute(1),
|
||||
ElementCompute beta = ElementCompute(0)) {
|
||||
ElementCompute beta = ElementCompute(0),
|
||||
bool tensor_C_row_broadcast = false) {
|
||||
|
||||
// Waive test if insufficient CUDA device
|
||||
if (!sufficient()) {
|
||||
@ -348,7 +361,7 @@ struct SparseTestbed {
|
||||
return true;
|
||||
}
|
||||
|
||||
this->initialize(problem_size);
|
||||
this->initialize(problem_size, tensor_C_row_broadcast);
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
@ -403,7 +416,7 @@ struct SparseTestbed {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Gemm>
|
||||
bool TestAllSparseGemm() {
|
||||
bool TestAllSparseGemm(bool tensor_C_row_broadcast = false) {
|
||||
bool passed = true;
|
||||
|
||||
int const kMinimumOperandElementSize =
|
||||
@ -463,7 +476,8 @@ bool TestAllSparseGemm() {
|
||||
problem_size,
|
||||
split_k,
|
||||
cutlass::from_real<ElementCompute>(alpha),
|
||||
cutlass::from_real<ElementCompute>(beta)
|
||||
cutlass::from_real<ElementCompute>(beta),
|
||||
tensor_C_row_broadcast
|
||||
);
|
||||
|
||||
if (!passed) {
|
||||
|
||||
Reference in New Issue
Block a user