Groupwise scaling along M for FP8 gemm (#2037)

* FP8 groupwise scaling along M

* small updates

---------

Co-authored-by: zl <zl@deepseek.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Liang
2025-02-01 02:51:28 +08:00
committed by GitHub
parent bdd641790a
commit 3c28697b9f
8 changed files with 1418 additions and 48 deletions

View File

@ -123,7 +123,7 @@ using ArchTag = cutlass::arch::Sm90; // T
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

View File

@ -0,0 +1,770 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
This example demonstrate a grouped scaled FP8 GEMM using the new CUTLASS 3.0.
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
which are more efficient than the Ampere tensor core instructions.
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
copies between thread blocks in a cluster.
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
4. This example shows all important fusions used by FP8 gemm kernels, i.e., grouped scale factor along M for
A, blocked scale factor along K for A tensor, blocked scale factor for B tensor, the abs_max value of D tensor.
5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
Examples:
$ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \
--m=2816 --n=3072 --k=16384 \
--save_aux=false --save_amax=false \
--device_scale=false --raster=h --swizzle=2
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
// Includes from examples directory
#include "helper.h"
#include "hopper_fp8_commandline.hpp"
#include "reference/host/gemm_with_groupwise_scaling.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// Auxiliary matrix configuration and other fusion types
using ElementAux = ElementC;
using LayoutAux = LayoutC;
using ElementAmax = float;
using ElementBias = float;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementBlockScale = float; // Element type for blockscaling during accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
constexpr int ScaleMsPerTile = 2;
constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopWithBlockWiseScaling,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Extract information from Gemm kernel.
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideAux = StrideD;
constexpr bool IsDFp8 =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
constexpr bool IsAuxFp8 =
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
"FP8 scaling granularity must evenly divide tile shape along M.");
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
"ElementAccumulator and ElementBlockScale should be same datatype");
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
StrideAux stride_aux;
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
uint32_t mma_promotion_interval;
cutlass::HostTensor<ElementBlockScale, LayoutA> blockscale_tensor_A;
cutlass::HostTensor<ElementBlockScale, LayoutB> blockscale_tensor_B;
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
using LayoutScalar = cutlass::layout::PackedVectorLayout;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_B;
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_C;
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_D;
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_aux;
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_D;
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_D;
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_aux;
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
scope_min = -1;
scope_max = 1;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implementated.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options<RasterOrderOptions> &options) {
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
stride_aux = stride_D;
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k);
auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l);
tensor_A.resize(a_coord);
blockscale_tensor_A.resize(groupscale_a_coord);
tensor_B.resize(b_coord);
blockscale_tensor_B.resize(blockscale_b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
cutlass::Distribution::Kind dist_A = cutlass::Distribution::Uniform;
cutlass::Distribution::Kind dist_B = cutlass::Distribution::Uniform;
cutlass::Distribution::Kind dist_C = cutlass::Distribution::Uniform;
cutlass::Distribution::Kind dist_scaleA = cutlass::Distribution::Uniform;
cutlass::Distribution::Kind dist_scaleB = cutlass::Distribution::Uniform;
initialize_tensor(tensor_A.host_view(), dist_A, seed + 2022);
initialize_tensor(tensor_B.host_view(), dist_B, seed + 2023);
initialize_tensor(tensor_C.host_view(), dist_C, seed + 2024);
initialize_scale_tensor(blockscale_tensor_A.host_view(), dist_scaleA, seed + 2025);
initialize_scale_tensor(blockscale_tensor_B.host_view(), dist_scaleB, seed + 2026);
#if 0 // Dump blockscaled tensors
std::cout << "blockscale_tensor_A: " << groupscale_a_coord << std::endl;
std::cout << blockscale_tensor_A.host_view() << "\n";
std::cout << "blockscale_tensor_B: " << blockscale_b_coord << std::endl;
std::cout << blockscale_tensor_B.host_view() << "\n";
#endif
// Print group scaling tensors on the host side.
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
blockscale_tensor_A.sync_device();
blockscale_tensor_B.sync_device();
mma_promotion_interval = 4;
if (options.save_aux) {
tensor_aux.resize(c_coord);
tensor_aux.sync_device();
tensor_ref_aux.resize(c_coord);
}
if (options.device_scale) {
scalar_alpha.resize(cutlass::make_Coord(1));
scalar_beta.resize(cutlass::make_Coord(1));
scale_A.resize(cutlass::make_Coord(1));
scale_B.resize(cutlass::make_Coord(1));
scale_C.resize(cutlass::make_Coord(1));
scale_D.resize(cutlass::make_Coord(1));
scale_aux.resize(cutlass::make_Coord(1));
cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha);
cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta);
cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a);
cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b);
cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c);
cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d);
cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux);
scalar_alpha.sync_device();
scalar_beta.sync_device();
scale_A.sync_device();
scale_B.sync_device();
scale_C.sync_device();
scale_D.sync_device();
scale_aux.sync_device();
}
if (IsDFp8 && options.save_amax) {
abs_max_D.resize(cutlass::make_Coord(1));
abs_max_D.sync_device();
reference_abs_max_D.resize(cutlass::make_Coord(1));
}
if (IsAuxFp8 && options.save_aux && options.save_amax) {
abs_max_aux.resize(cutlass::make_Coord(1));
abs_max_aux.sync_device();
reference_abs_max_aux.resize(cutlass::make_Coord(1));
}
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(),
stride_A,
tensor_B.device_data(),
stride_B,
mma_promotion_interval,
blockscale_tensor_A.device_data(),
blockscale_tensor_B.device_data()
},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = scalar_alpha.device_data();
fusion_args.beta_ptr = scalar_beta.device_data();
fusion_args.scale_a = options.scale_a;
fusion_args.scale_b = options.scale_b;
fusion_args.scale_c = options.scale_c;
fusion_args.scale_a_ptr = scale_A.device_data();
fusion_args.scale_b_ptr = scale_B.device_data();
fusion_args.scale_c_ptr = scale_C.device_data();
// ignored if tensor types are not fp8
fusion_args.scale_d = options.scale_d;
fusion_args.scale_aux = options.scale_aux;
fusion_args.scale_d_ptr = scale_D.device_data();
fusion_args.scale_aux_ptr = scale_aux.device_data();
// leaving/setting these as nullptr disables the fusion at runtime
fusion_args.bias_ptr = nullptr;
if (options.save_aux) {
fusion_args.aux_ptr = tensor_aux.device_data();
fusion_args.dAux = stride_aux;
if (options.save_amax) {
fusion_args.amax_aux_ptr = abs_max_aux.device_data();
}
}
if (options.save_amax) {
fusion_args.amax_D_ptr = abs_max_D.device_data();
}
arguments.scheduler.raster_order = options.raster;
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
bool verify(const Options<RasterOrderOptions> &options) {
//
// Compute reference output
//
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(
cute::make_shape(options.m, options.k, options.l),
stride_A
)
);
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(
cute::make_shape(options.n, options.k, options.l),
stride_B
)
);
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(
cute::make_shape(options.m, options.n, options.l),
stride_C
)
);
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(
cute::make_shape(options.m, options.n, options.l),
stride_D
)
);
auto Aux = cute::make_tensor(tensor_ref_aux.host_data(),
cute::make_layout(
cute::make_shape(options.m, options.n, options.l),
stride_aux
)
);
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
cute::make_layout(
cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
)
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
cute::make_shape(blockscale_n, blockscale_k, options.l),
cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
)
);
using unused_t = decltype(D);
cutlass::reference::host::GettMainloopParams<ElementAccumulator,
decltype(A), decltype(B),
decltype(blockscale_A), decltype(blockscale_B),
TileShape> mainloop_params{
A, B, // Operand Tensors
blockscale_A, blockscale_B // Groupwise scaling Tensors
};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
ElementScalar,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
decltype(Aux),
unused_t, // valpha
unused_t, // vbeta
ActivationFunctor
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.Aux = Aux;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
epilogue_params.scale_a = options.scale_a;
epilogue_params.scale_b = options.scale_b;
epilogue_params.scale_c = options.scale_c;
epilogue_params.scale_d = options.scale_d;
epilogue_params.scale_aux = options.scale_aux;
epilogue_params.abs_max_D = reference_abs_max_D.host_data();
epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data();
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
if (false) {
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
}
if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
}
if (options.save_aux) {
tensor_aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
if (IsAuxFp8 && options.save_amax) {
abs_max_aux.sync_host();
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
}
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options<RasterOrderOptions> &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
// if (!result.passed) {
// exit(-1);
// }
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options<RasterOrderOptions> options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -30,3 +30,8 @@ cutlass_example_add_executable(
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
)
cutlass_example_add_executable(
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
)

View File

@ -0,0 +1,507 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for GETT in host-side code.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/gemm.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/relatively_equal.h"
#include <iostream>
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::reference::host {
template<class T, class = void>
struct ElementTraits {
using type = T;
};
template<class T>
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
using type = decltype(std::declval<T>().get());
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementAccumulator_,
class TensorA_, // (M, K, L)
class TensorB_, // (N, K, L)
class TensorScaleA_, // (m, k, L)
class TensorScaleB_, // (n, k, L)
class TileShape_
>
struct GettMainloopParams {
using ElementAccumulator = ElementAccumulator_;
using TensorA = TensorA_;
using TensorB = TensorB_;
using EngineA = typename TensorA::engine_type;
using LayoutA = typename TensorA::layout_type;
using EngineB = typename TensorB::engine_type;
using LayoutB = typename TensorB::layout_type;
using TensorScaleA = TensorScaleA_;
using TensorScaleB = TensorScaleB_;
using TileShape = TileShape_;
using EngineScaleA = typename TensorScaleA::engine_type;
using EngineScaleB = typename TensorScaleB::engine_type;
TensorA A{};
TensorB B{};
TensorScaleA ScaleA{};
TensorScaleB ScaleB{};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementScalingFactor_,
class ElementAccumulator_,
class ElementCompute_,
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = TensorD_, // (M, 1)
class TensorAux_ = TensorD_, // (M, N, L)
class VectorAlpha_ = TensorD_, // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
bool PerColumnBias_ = false
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
using ElementScalingFactor = ElementScalingFactor_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using TensorC = TensorC_;
using TensorD = TensorD_;
using TensorAux = TensorAux_;
using VectorBias = VectorBias_;
using VectorAlpha = VectorAlpha_;
using VectorBeta = VectorBeta_;
using ActivationFunctor = ActivationFunctor_;
using BiasBinaryOp = BiasBinaryOp_;
using EngineC = typename TensorC::engine_type;
using LayoutC = typename TensorC::layout_type;
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
static constexpr bool PerColumnBias = PerColumnBias_;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
TensorC C{};
TensorD D{};
VectorBias Bias{};
TensorAux Aux{};
VectorAlpha Valpha{};
VectorBeta Vbeta{};
ElementCompute st = ElementCompute(1);
ElementAccumulator* abs_max_D = nullptr;
ElementAccumulator* abs_max_Aux = nullptr;
ElementScalingFactor scale_a = ElementScalingFactor(1);
ElementScalingFactor scale_b = ElementScalingFactor(1);
ElementScalingFactor scale_c = ElementScalingFactor(1);
ElementScalingFactor scale_d = ElementScalingFactor(1);
ElementScalingFactor scale_aux = ElementScalingFactor(1);
bool beta_per_channel_scaling = false;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling
template <
class MainloopParams,
class EpilogueParams
>
void Gett(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
#if defined(_OPENMP)
#pragma omp parallel for collapse(3)
#endif
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
gett_mainloop(mainloop_params, m, n, l, acc);
gett_epilogue(epilogue_params, m, n, l, acc);
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Mainloop
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_mainloop(
MainloopParams const& mainloop_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
RingOp fma_op;
multiplies<ElementAccumulator> scale_op;
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});;
// Tempo accumulators to seperate blockwise accumulation
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
// Zero out accumulators
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
int64_t block_m = m / kBlockM;
int64_t block_n = n / kBlockN;
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l);
const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
// Load Blockwise scaling factor from blockscale Tensors for B
int64_t block_k = k / kBlockK;
cute::Tensor scale_a = blockscale_A(_, block_k);
ElementBlockScaleB scale_b = blockscale_B[block_k];
// Load A
ElementAccumulator a_frag[kBlockM];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
} else {
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// Load B
ElementAccumulator b_frag[kBlockN];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
} else {
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// do compute
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
}
}
// Apply Groupwise-scaling at kBlockK boundary
// (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary
// (b) Zero-out partial temporary (acc_temp),
// (c) Update permanent (accu)
if ((k+1) % kBlockK == 0) {
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a[m_b / ScaleGranularityM] * scale_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Epilogue
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_epilogue(
EpilogueParams const& epilogue_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::TensorC::value_type;
using ElementD = typename EpilogueParams::TensorD::value_type;
using ElementAux = typename EpilogueParams::TensorAux::value_type;
using ElementBias = typename EpilogueParams::VectorBias::value_type;
using ElementScalar = typename EpilogueParams::ElementScalar;
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
constexpr bool IsReLUAuxNeeded =
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
constexpr bool IsClamp =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
constexpr bool IsBackpropFusion =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
// Input related converter
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
NumericConverter<ElementCompute, ElementC> source_converter;
NumericConverter<ElementCompute, ElementBias> bias_converter;
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
// Scale related converter
NumericConverter<ElementCompute, ElementScalar> scale_converter;
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
// Abs max converter
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
// Output related converter
NumericConverter<ElementD, ElementCompute> destination_converter;
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
NumericConverter<ElementBias, ElementCompute> dBias_converter;
// Epilogue operations
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
multiplies<ElementCompute> mul;
plus<ElementCompute> add;
// Activation operation
ActivationFunctor activation;
// Bias binary operation
BiasBinaryOp bias_op;
// Do conversion
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
// Init local var
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
converted_beta = mul(converted_beta, converted_scale_c);
ElementCompute inter_accum[kBlockM][kBlockN];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
ElementCompute local_dBias = ElementCompute(0);
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
// per-row alpha
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
}
ElementCompute output = mul(converted_alpha, converted_acc);
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
output = bias_op(output, converted_bias);
}
if (raw_pointer_cast(epilogue_params.C.data())) {
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
// per-row beta
if (epilogue_params.Vbeta.data()) {
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
}
output = epilogue_fma(converted_beta, converted_src, output);
}
if constexpr (IsBackpropFusion) {
ElementAux aux_input = ElementAux(0);
if (raw_pointer_cast(epilogue_params.Aux.data())) {
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
}
output = activation(output, aux_source_converter(aux_input));
local_dBias = add(local_dBias, output);
}
else {
if (raw_pointer_cast(epilogue_params.Aux.data())) {
auto aux_output = output;
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
}
if constexpr (IsReLUAuxNeeded) {
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
} else {
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
}
}
if constexpr (IsClamp) { // Treat Clamp as ReLU
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
}
else {
output = activation(output);
}
}
if constexpr (IsScalingAndAmaxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_output = amax_op(local_abs_max_output, output);
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
}
inter_accum[m_b][n_b] = ElementCompute(output);
}
} // n_b
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
local_dBias = add(local_dBias, converted_dBias);
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
}
}
} // m_b
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
}
}
}
#if defined(_OPENMP)
#pragma omp critical(Abs_Max_Data_Update)
#endif
{
if constexpr (IsScalingAndAmaxOutputNeeded) {
if (epilogue_params.abs_max_D) {
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
}
}
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
if (epilogue_params.abs_max_Aux) {
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM - General Matrix-Matrix contraction without conjugation options
template <
class MainloopParams,
class EpilogueParams
>
void Gemm3x(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
using namespace cute;
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) "
"with Batchmode are supported");
// Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett).
Gett(mainloop_params, epilogue_params);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // cutlass::reference::host
/////////////////////////////////////////////////////////////////////////////////////////////////