Allow per-column bias in EpilogueTensorBroadcast (#1275)
* Allow per-column bias in EpilogueTensorBroadcast. EpilogueTensorBroadcast previously supported only per-row vector broadcast, because the bias stride was hardcoded. It can easily support both if the bias stride is made conditional, and the original behavior is maintained by defaulting to per-row. * Add a unit test for EpilogueTensorBroadcast with per-column bias. --------- Co-authored-by: Ali Hassani <ahassanijr@gmail.com> Co-authored-by: Ali Hassani <ali@hippoml.com>
This commit is contained in:
@ -76,6 +76,8 @@ struct Testbed3xTensorBroadcast {
|
||||
static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled;
|
||||
static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled;
|
||||
|
||||
static constexpr bool PerColBias = Epilogue::PerColumnBias;
|
||||
|
||||
using LayoutTagA = typename TestBedImpl::LayoutTagA;
|
||||
using LayoutTagB = typename TestBedImpl::LayoutTagB;
|
||||
using LayoutTagC = typename TestBedImpl::LayoutTagC;
|
||||
@ -130,8 +132,8 @@ struct Testbed3xTensorBroadcast {
|
||||
|
||||
void initialize_bias(ProblemShapeType problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto M = cute::get<0>(problem_shape_MNKL);
|
||||
bias.resize(cutlass::Coord<1>(M));
|
||||
auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL);
|
||||
bias.resize(cutlass::Coord<1>(bias_size));
|
||||
|
||||
EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023));
|
||||
bias.sync_device();
|
||||
@ -186,7 +188,8 @@ struct Testbed3xTensorBroadcast {
|
||||
std::ofstream file(fname.str());
|
||||
file
|
||||
<< "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
|
||||
<< ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias << "\n\n";
|
||||
<< ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias
|
||||
<< ", per-col bias: " << PerColBias << "\n\n";
|
||||
|
||||
if (use_bias){
|
||||
file << "Bias = \n" << bias.host_view()<< "\n\n";
|
||||
@ -225,7 +228,7 @@ struct Testbed3xTensorBroadcast {
|
||||
auto D = cute::make_tensor(impl_.reference_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d));
|
||||
auto Bias = cute::make_tensor(static_cast<ElementBias*>(use_bias ? bias.host_data() : nullptr),
|
||||
cute::make_layout(cute::make_shape(M, 1)));
|
||||
cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1)));
|
||||
auto C0 = cute::make_tensor(impl_.tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c));
|
||||
auto C1 = cute::make_tensor(tensor_C1.host_data(),
|
||||
@ -263,7 +266,9 @@ struct Testbed3xTensorBroadcast {
|
||||
decltype(dummy_Aux),
|
||||
decltype(dummy_Valpha),
|
||||
decltype(dummy_Vbeta),
|
||||
ActivationFunctor> epilogue_params{
|
||||
ActivationFunctor,
|
||||
cutlass::plus<ElementCompute>,
|
||||
PerColBias> epilogue_params{
|
||||
alpha,
|
||||
dummy_beta,
|
||||
dummy_C,
|
||||
|
||||
@ -97,6 +97,54 @@ TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32_tensor_broadcast, 64x128
|
||||
EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast<Gemm>());
|
||||
}
|
||||
|
||||
// Functional test for EpilogueTensorBroadcast with a per-column (1xN) bias
// vector instead of the default per-row (Mx1) one. Exercises a ReLU
// activation, binary multiply, binary plus, and a HardSwish unary op on SM90.
TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32_tensor_broadcast, 64x128x32_1x2x1_ActReLU_Bin0Mul_Bin1Plus_UnaryHardSwish_PerColBias) {
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;

  // Every operand (output, accumulator, compute, bias) is fp32.
  using ElementOutput      = float;
  using ElementAccumulator = ElementOutput;
  using ElementCompute     = ElementOutput;
  using ElementBias        = ElementOutput;

  // Mainloop: SM90 tensor-op collective, auto-selected stage count/schedule.
  using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      float, LayoutA, 4,
      float, LayoutB, 4,
      float,
      Shape<_64,_128,_128>, Shape<_1,_2,_1>,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;

  // Epilogue under test. The trailing `true` is the new PerColBias template
  // argument: the bias is broadcast per column (1xN) rather than per row.
  using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
    cutlass::epilogue::collective::EpilogueTensorBroadcast<
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::epilogue::thread::LinearCombinationTensorBroadcast<
        ElementOutput, ElementAccumulator, ElementCompute, ElementBias,
        cutlass::epilogue::thread::ReLu,
        cutlass::multiplies,
        cutlass::plus,
        cutlass::epilogue::thread::HardSwish
      >,
      cutlass::gemm::EpilogueDefault,
      /* PerColBias = */ true>>;

  // Sanity-check that both binary ops and the unary op are enabled for this
  // epilogue configuration.
  EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled);
  EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled);
  EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled);

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      CollectiveOp,
      EpilogueOp
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast<Gemm>());
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
Reference in New Issue
Block a user