Groupwise scaling along M for FP8 gemm (#2037)
* FP8 groupwise scaling along M * small updates --------- Co-authored-by: zl <zl@deepseek.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@ -123,7 +123,7 @@ using ArchTag = cutlass::arch::Sm90; // T
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
@ -0,0 +1,770 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Grouped scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
|
||||
This example demonstrate a grouped scaled FP8 GEMM using the new CUTLASS 3.0.
|
||||
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster.
|
||||
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
|
||||
4. This example shows all important fusions used by FP8 gemm kernels, i.e., grouped scale factor along M for
|
||||
A, blocked scale factor along K for A tensor, blocked scale factor for B tensor, the abs_max value of D tensor.
|
||||
|
||||
5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
|
||||
Examples:
|
||||
|
||||
$ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \
|
||||
--m=2816 --n=3072 --k=16384 \
|
||||
--save_aux=false --save_amax=false \
|
||||
--device_scale=false --raster=h --swizzle=2
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
#include "reference/host/gemm_with_groupwise_scaling.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Auxiliary matrix configuration and other fusion types
|
||||
using ElementAux = ElementC;
|
||||
using LayoutAux = LayoutC;
|
||||
using ElementAmax = float;
|
||||
using ElementBias = float;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
constexpr int ScaleMsPerTile = 2;
|
||||
constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile;
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopWithBlockWiseScaling,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideAux = StrideD;
|
||||
|
||||
constexpr bool IsDFp8 =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsAuxFp8 =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
|
||||
"FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
StrideAux stride_aux;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
uint32_t mma_promotion_interval;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutA> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutB> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_B;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_C;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_D;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_aux = stride_D;
|
||||
|
||||
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k);
|
||||
auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
blockscale_tensor_A.resize(groupscale_a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
blockscale_tensor_B.resize(blockscale_b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
cutlass::Distribution::Kind dist_A = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_B = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_C = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_scaleA = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_scaleB = cutlass::Distribution::Uniform;
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), dist_A, seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), dist_B, seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), dist_C, seed + 2024);
|
||||
initialize_scale_tensor(blockscale_tensor_A.host_view(), dist_scaleA, seed + 2025);
|
||||
initialize_scale_tensor(blockscale_tensor_B.host_view(), dist_scaleB, seed + 2026);
|
||||
|
||||
#if 0 // Dump blockscaled tensors
|
||||
std::cout << "blockscale_tensor_A: " << groupscale_a_coord << std::endl;
|
||||
std::cout << blockscale_tensor_A.host_view() << "\n";
|
||||
std::cout << "blockscale_tensor_B: " << blockscale_b_coord << std::endl;
|
||||
std::cout << blockscale_tensor_B.host_view() << "\n";
|
||||
#endif
|
||||
|
||||
// Print group scaling tensors on the host side.
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
blockscale_tensor_A.sync_device();
|
||||
blockscale_tensor_B.sync_device();
|
||||
|
||||
mma_promotion_interval = 4;
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.resize(c_coord);
|
||||
tensor_aux.sync_device();
|
||||
tensor_ref_aux.resize(c_coord);
|
||||
}
|
||||
|
||||
if (options.device_scale) {
|
||||
scalar_alpha.resize(cutlass::make_Coord(1));
|
||||
scalar_beta.resize(cutlass::make_Coord(1));
|
||||
scale_A.resize(cutlass::make_Coord(1));
|
||||
scale_B.resize(cutlass::make_Coord(1));
|
||||
scale_C.resize(cutlass::make_Coord(1));
|
||||
scale_D.resize(cutlass::make_Coord(1));
|
||||
scale_aux.resize(cutlass::make_Coord(1));
|
||||
|
||||
cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha);
|
||||
cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta);
|
||||
cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a);
|
||||
cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b);
|
||||
cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c);
|
||||
cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d);
|
||||
cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux);
|
||||
|
||||
scalar_alpha.sync_device();
|
||||
scalar_beta.sync_device();
|
||||
scale_A.sync_device();
|
||||
scale_B.sync_device();
|
||||
scale_C.sync_device();
|
||||
scale_D.sync_device();
|
||||
scale_aux.sync_device();
|
||||
}
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.resize(cutlass::make_Coord(1));
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(cutlass::make_Coord(1));
|
||||
}
|
||||
|
||||
if (IsAuxFp8 && options.save_aux && options.save_amax) {
|
||||
abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
abs_max_aux.sync_device();
|
||||
reference_abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(),
|
||||
stride_A,
|
||||
tensor_B.device_data(),
|
||||
stride_B,
|
||||
mma_promotion_interval,
|
||||
blockscale_tensor_A.device_data(),
|
||||
blockscale_tensor_B.device_data()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = scalar_alpha.device_data();
|
||||
fusion_args.beta_ptr = scalar_beta.device_data();
|
||||
fusion_args.scale_a = options.scale_a;
|
||||
fusion_args.scale_b = options.scale_b;
|
||||
fusion_args.scale_c = options.scale_c;
|
||||
fusion_args.scale_a_ptr = scale_A.device_data();
|
||||
fusion_args.scale_b_ptr = scale_B.device_data();
|
||||
fusion_args.scale_c_ptr = scale_C.device_data();
|
||||
|
||||
// ignored if tensor types are not fp8
|
||||
fusion_args.scale_d = options.scale_d;
|
||||
fusion_args.scale_aux = options.scale_aux;
|
||||
fusion_args.scale_d_ptr = scale_D.device_data();
|
||||
fusion_args.scale_aux_ptr = scale_aux.device_data();
|
||||
|
||||
// leaving/setting these as nullptr disables the fusion at runtime
|
||||
fusion_args.bias_ptr = nullptr;
|
||||
|
||||
if (options.save_aux) {
|
||||
fusion_args.aux_ptr = tensor_aux.device_data();
|
||||
fusion_args.dAux = stride_aux;
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_aux_ptr = abs_max_aux.device_data();
|
||||
}
|
||||
}
|
||||
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options<RasterOrderOptions> &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.k, options.l),
|
||||
stride_A
|
||||
)
|
||||
);
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.n, options.k, options.l),
|
||||
stride_B
|
||||
)
|
||||
);
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.n, options.l),
|
||||
stride_C
|
||||
)
|
||||
);
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.n, options.l),
|
||||
stride_D
|
||||
)
|
||||
);
|
||||
auto Aux = cute::make_tensor(tensor_ref_aux.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.n, options.l),
|
||||
stride_aux
|
||||
)
|
||||
);
|
||||
|
||||
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
|
||||
)
|
||||
);
|
||||
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_n, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
|
||||
)
|
||||
);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator,
|
||||
decltype(A), decltype(B),
|
||||
decltype(blockscale_A), decltype(blockscale_B),
|
||||
TileShape> mainloop_params{
|
||||
A, B, // Operand Tensors
|
||||
blockscale_A, blockscale_B // Groupwise scaling Tensors
|
||||
};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
decltype(Aux),
|
||||
unused_t, // valpha
|
||||
unused_t, // vbeta
|
||||
ActivationFunctor
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.Aux = Aux;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
epilogue_params.scale_a = options.scale_a;
|
||||
epilogue_params.scale_b = options.scale_b;
|
||||
epilogue_params.scale_c = options.scale_c;
|
||||
epilogue_params.scale_d = options.scale_d;
|
||||
epilogue_params.scale_aux = options.scale_aux;
|
||||
epilogue_params.abs_max_D = reference_abs_max_D.host_data();
|
||||
epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data();
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
if (false) {
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.sync_host();
|
||||
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
|
||||
}
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
|
||||
if (IsAuxFp8 && options.save_amax) {
|
||||
abs_max_aux.sync_host();
|
||||
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
|
||||
}
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options<RasterOrderOptions> &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -30,3 +30,8 @@ cutlass_example_add_executable(
|
||||
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
|
||||
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
|
||||
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
|
||||
)
|
||||
@ -0,0 +1,507 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GETT in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/relatively_equal.h"
|
||||
#include <iostream>
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::reference::host {
|
||||
|
||||
template<class T, class = void>
|
||||
struct ElementTraits {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
|
||||
using type = decltype(std::declval<T>().get());
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAccumulator_,
|
||||
class TensorA_, // (M, K, L)
|
||||
class TensorB_, // (N, K, L)
|
||||
class TensorScaleA_, // (m, k, L)
|
||||
class TensorScaleB_, // (n, k, L)
|
||||
class TileShape_
|
||||
>
|
||||
struct GettMainloopParams {
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using TensorA = TensorA_;
|
||||
using TensorB = TensorB_;
|
||||
using EngineA = typename TensorA::engine_type;
|
||||
using LayoutA = typename TensorA::layout_type;
|
||||
using EngineB = typename TensorB::engine_type;
|
||||
using LayoutB = typename TensorB::layout_type;
|
||||
|
||||
using TensorScaleA = TensorScaleA_;
|
||||
using TensorScaleB = TensorScaleB_;
|
||||
using TileShape = TileShape_;
|
||||
using EngineScaleA = typename TensorScaleA::engine_type;
|
||||
using EngineScaleB = typename TensorScaleB::engine_type;
|
||||
|
||||
TensorA A{};
|
||||
TensorB B{};
|
||||
TensorScaleA ScaleA{};
|
||||
TensorScaleB ScaleB{};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementScalingFactor_,
|
||||
class ElementAccumulator_,
|
||||
class ElementCompute_,
|
||||
class TensorC_, // (M, N, L)
|
||||
class TensorD_, // (M, N, L)
|
||||
class VectorBias_ = TensorD_, // (M, 1)
|
||||
class TensorAux_ = TensorD_, // (M, N, L)
|
||||
class VectorAlpha_ = TensorD_, // (M, 1)
|
||||
class VectorBeta_ = VectorAlpha_, // (M, 1)
|
||||
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
|
||||
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
|
||||
bool PerColumnBias_ = false
|
||||
>
|
||||
struct GettEpilogueParams {
|
||||
using ElementScalar = ElementScalar_;
|
||||
using ElementScalingFactor = ElementScalingFactor_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using TensorC = TensorC_;
|
||||
using TensorD = TensorD_;
|
||||
using TensorAux = TensorAux_;
|
||||
using VectorBias = VectorBias_;
|
||||
using VectorAlpha = VectorAlpha_;
|
||||
using VectorBeta = VectorBeta_;
|
||||
using ActivationFunctor = ActivationFunctor_;
|
||||
using BiasBinaryOp = BiasBinaryOp_;
|
||||
|
||||
using EngineC = typename TensorC::engine_type;
|
||||
using LayoutC = typename TensorC::layout_type;
|
||||
using EngineD = typename TensorD::engine_type;
|
||||
using LayoutD = typename TensorD::layout_type;
|
||||
static constexpr bool PerColumnBias = PerColumnBias_;
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
|
||||
TensorC C{};
|
||||
TensorD D{};
|
||||
VectorBias Bias{};
|
||||
TensorAux Aux{};
|
||||
VectorAlpha Valpha{};
|
||||
VectorBeta Vbeta{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
|
||||
ElementAccumulator* abs_max_D = nullptr;
|
||||
ElementAccumulator* abs_max_Aux = nullptr;
|
||||
|
||||
ElementScalingFactor scale_a = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_b = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_c = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_d = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_aux = ElementScalingFactor(1);
|
||||
|
||||
bool beta_per_channel_scaling = false;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gett(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
|
||||
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
|
||||
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
|
||||
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
|
||||
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel for collapse(3)
|
||||
#endif
|
||||
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
|
||||
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
|
||||
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
|
||||
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
|
||||
gett_mainloop(mainloop_params, m, n, l, acc);
|
||||
gett_epilogue(epilogue_params, m, n, l, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Mainloop
|
||||
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_mainloop(
|
||||
MainloopParams const& mainloop_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
||||
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
||||
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
|
||||
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
|
||||
|
||||
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
|
||||
RingOp fma_op;
|
||||
|
||||
multiplies<ElementAccumulator> scale_op;
|
||||
|
||||
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});;
|
||||
|
||||
// Tempo accumulators to seperate blockwise accumulation
|
||||
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
|
||||
|
||||
// Zero out accumulators
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t block_m = m / kBlockM;
|
||||
int64_t block_n = n / kBlockN;
|
||||
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
|
||||
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l);
|
||||
|
||||
const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
|
||||
assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
|
||||
|
||||
// Compute on this k-block
|
||||
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
|
||||
|
||||
// Load Blockwise scaling factor from blockscale Tensors for B
|
||||
int64_t block_k = k / kBlockK;
|
||||
cute::Tensor scale_a = blockscale_A(_, block_k);
|
||||
ElementBlockScaleB scale_b = blockscale_B[block_k];
|
||||
|
||||
// Load A
|
||||
ElementAccumulator a_frag[kBlockM];
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
||||
} else {
|
||||
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// Load B
|
||||
ElementAccumulator b_frag[kBlockN];
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
||||
} else {
|
||||
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// do compute
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Groupwise-scaling at kBlockK boundary
|
||||
// (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary
|
||||
// (b) Zero-out partial temporary (acc_temp),
|
||||
// (c) Update permanent (accu)
|
||||
if ((k+1) % kBlockK == 0) {
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a[m_b / ScaleGranularityM] * scale_b;
|
||||
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Epilogue
|
||||
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_epilogue(
|
||||
EpilogueParams const& epilogue_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementCompute = typename EpilogueParams::ElementCompute;
|
||||
using ElementC = typename EpilogueParams::TensorC::value_type;
|
||||
using ElementD = typename EpilogueParams::TensorD::value_type;
|
||||
using ElementAux = typename EpilogueParams::TensorAux::value_type;
|
||||
using ElementBias = typename EpilogueParams::VectorBias::value_type;
|
||||
using ElementScalar = typename EpilogueParams::ElementScalar;
|
||||
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
|
||||
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsReLUAuxNeeded =
|
||||
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
|
||||
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
|
||||
constexpr bool IsClamp =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
|
||||
|
||||
constexpr bool IsBackpropFusion =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
|
||||
|
||||
// Input related converter
|
||||
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
|
||||
NumericConverter<ElementCompute, ElementC> source_converter;
|
||||
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
|
||||
|
||||
// Scale related converter
|
||||
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
||||
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
|
||||
|
||||
// Abs max converter
|
||||
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
|
||||
|
||||
// Output related converter
|
||||
NumericConverter<ElementD, ElementCompute> destination_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
|
||||
NumericConverter<ElementBias, ElementCompute> dBias_converter;
|
||||
|
||||
// Epilogue operations
|
||||
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
|
||||
multiplies<ElementCompute> mul;
|
||||
plus<ElementCompute> add;
|
||||
|
||||
// Activation operation
|
||||
ActivationFunctor activation;
|
||||
|
||||
// Bias binary operation
|
||||
BiasBinaryOp bias_op;
|
||||
|
||||
// Do conversion
|
||||
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
|
||||
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
|
||||
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
|
||||
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
|
||||
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
|
||||
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
|
||||
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
|
||||
|
||||
// Init local var
|
||||
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
|
||||
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
|
||||
|
||||
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
|
||||
converted_beta = mul(converted_beta, converted_scale_c);
|
||||
|
||||
ElementCompute inter_accum[kBlockM][kBlockN];
|
||||
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
ElementCompute local_dBias = ElementCompute(0);
|
||||
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
|
||||
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
|
||||
// per-row alpha
|
||||
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
|
||||
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
|
||||
}
|
||||
ElementCompute output = mul(converted_alpha, converted_acc);
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
|
||||
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
|
||||
output = bias_op(output, converted_bias);
|
||||
}
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.C.data())) {
|
||||
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
|
||||
// per-row beta
|
||||
if (epilogue_params.Vbeta.data()) {
|
||||
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
|
||||
}
|
||||
output = epilogue_fma(converted_beta, converted_src, output);
|
||||
}
|
||||
|
||||
if constexpr (IsBackpropFusion) {
|
||||
ElementAux aux_input = ElementAux(0);
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
|
||||
}
|
||||
|
||||
output = activation(output, aux_source_converter(aux_input));
|
||||
local_dBias = add(local_dBias, output);
|
||||
}
|
||||
else {
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
auto aux_output = output;
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
|
||||
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
|
||||
}
|
||||
|
||||
if constexpr (IsReLUAuxNeeded) {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
|
||||
} else {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsClamp) { // Treat Clamp as ReLU
|
||||
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
|
||||
}
|
||||
else {
|
||||
output = activation(output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_output = amax_op(local_abs_max_output, output);
|
||||
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
|
||||
}
|
||||
|
||||
inter_accum[m_b][n_b] = ElementCompute(output);
|
||||
}
|
||||
} // n_b
|
||||
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
|
||||
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
|
||||
local_dBias = add(local_dBias, converted_dBias);
|
||||
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
|
||||
}
|
||||
}
|
||||
} // m_b
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp critical(Abs_Max_Data_Update)
|
||||
#endif
|
||||
{
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_D) {
|
||||
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_Aux) {
|
||||
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM - General Matrix-Matrix contraction without conjugation options
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gemm3x(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) "
|
||||
"with Batchmode are supported");
|
||||
// Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett).
|
||||
Gett(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // cutlass::reference::host
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
Reference in New Issue
Block a user