Blockwise and Groupwise GEMM for Blackwell and Improvements for Hopper (#2139)

- Blockwise and Groupwise GEMM improvements for Hopper.
- Blockwise and Groupwise GEMM for Blackwell.
- Blockwise Grouped GEMM for Hopper.
- Static ScalePromotionInterval for Hopper FP8 GEMMs.

Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
dePaul Miller
2025-02-26 09:44:58 -08:00
committed by GitHub
parent eefa171318
commit ca4fdbea70
28 changed files with 6860 additions and 71 deletions

View File

@ -398,6 +398,10 @@ void initialize(const Options<RasterOrderOptions> &options) {
blockscale_tensor_A.sync_device();
blockscale_tensor_B.sync_device();
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
// Else kernel will fail can_implement() check
// Deprecation Notice : We plan to remove this params member in an upcoming release
// Users can safely delete this line from their code, since the default is already 4
mma_promotion_interval = 4;
if (options.save_aux) {
@ -662,9 +666,11 @@ int run(Options<RasterOrderOptions> &options)
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
if (options.verify) {
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}
// if (!result.passed) {
// exit(-1);
@ -674,8 +680,9 @@ int run(Options<RasterOrderOptions> &options)
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
if (iter == options.warmup)
timer.start();
CUTLASS_CHECK(gemm.run());
}
timer.stop();
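The hunk above changes the profiling loop so that options.warmup iterations run before the timer starts, keeping warmup work out of the reported average. A minimal standalone sketch of the same warmup-exclusion pattern (illustrative only: run_once() is a hypothetical stand-in for gemm.run(), host-side std::chrono replaces the example's event-based GpuTimer, and the caller is assumed to pass iterations > 0):

#include <chrono>

// Runs `warmup` untimed iterations, then `iterations` timed ones, returning the average in ms.
template <class F>
double time_after_warmup_ms(F run_once, int warmup, int iterations) {
  using clock = std::chrono::steady_clock;
  clock::time_point start{};
  for (int iter = 0; iter < warmup + iterations; ++iter) {
    if (iter == warmup) {
      start = clock::now();   // start the clock only once the warmup iterations are done
    }
    run_once();
  }
  double elapsed_ms = std::chrono::duration<double, std::milli>(clock::now() - start).count();
  return elapsed_ms / iterations; // average per timed iteration
}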

View File

@ -453,6 +453,10 @@ void initialize(const Options<RasterOrderOptions> &options) {
blockscale_tensor_A.sync_device();
blockscale_tensor_B.sync_device();
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
// Else kernel will fail can_implement() check
// Deprecation Notice : We plan to remove this params member in an upcoming release
// Users can safely delete this line from their code, since the default is already 4
mma_promotion_interval = 4;
if (options.save_aux) {
@ -668,14 +672,14 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
if (false) {
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
}
#if 0
std::cout << "tensor_ref_D.host_view() {" << std::endl
<< tensor_ref_D.host_view() << std::endl
<< "}" << std::endl;
std::cout << "tensor_D.host_view() {" << std::endl
<< tensor_D.host_view() << std::endl
<< "}" << std::endl;
#endif
if (IsDFp8 && options.save_amax) {
abs_max_D.sync_host();
@ -729,13 +733,15 @@ int run(Options<RasterOrderOptions> &options)
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
if (options.verify) {
result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
}
// if (!result.passed) {
// exit(-1);
// }
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)

View File

@ -34,6 +34,7 @@ template<typename RasterOrderOptions>
struct Options {
bool help = false;
bool verify = true;
float alpha = 1.f, beta = 0.f;
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
@ -41,6 +42,7 @@ struct Options {
bool save_aux = true;
bool save_amax = true;
int iterations = 1000;
int warmup = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
RasterOrderOptions raster;
int swizzle;
@ -68,7 +70,9 @@ struct Options {
cmd.get_cmd_line_argument("device_scale", device_scale, false);
cmd.get_cmd_line_argument("save_aux", save_aux, true);
cmd.get_cmd_line_argument("save_amax", save_amax, true);
cmd.get_cmd_line_argument("warmup", warmup);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("verify", verify);
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
@ -89,8 +93,8 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n"
<< " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
@ -113,7 +117,7 @@ struct Options {
out
<< "\n\nExamples:\n\n"
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
<< "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}

View File

@ -0,0 +1,841 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Groupwise-scaled Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture
This example demonstrates a groupwise-scaled FP8 Grouped GEMM using the new CUTLASS 3.0
APIs on the NVIDIA Hopper architecture. New features showcased in this example are as follows:
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
which are more efficient than the Ampere tensor core instructions.
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
descriptors to move between groups/problem_count (represented by groups).
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
Examples:
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling \
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
Where test_benchmark.txt may look like the following:
0 256x512x128
1 256x512x512
2 512x256x128
3 256x256x128
4 256x512x1024
5 1024x512x128 and so on
*/
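// Illustrative note (hypothetical helper, not part of the original example): each benchmark
// line "<index> MxNxK" becomes one group, and extents read from the file are rounded up to
// the 16-element (128-bit) TMA alignment for FP8 inputs in hopper_fp8_commandline.hpp.
constexpr int round_up_to_fp8_tma_alignment(int x) {
constexpr int alignment = 128 / 8; // tma_alignment_bits / sizeof_bits<float_e4m3_t>
return (x % alignment) ? x + (alignment - (x % alignment)) : x;
}
static_assert(round_up_to_fp8_tma_alignment(500) == 512, "500 rounds up to the next multiple of 16");
static_assert(round_up_to_fp8_tma_alignment(512) == 512, "aligned extents are left unchanged");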
#include <iostream>
#include <optional>
#include <fstream>
#include <sstream>
#include <vector>
#include <cfloat>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/tensor_fill.h"
// Includes from examples directory
#include "helper.h"
#include "hopper_fp8_commandline.hpp"
#include "reference/host/gemm_with_groupwise_scaling.h"
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementBlockScale = float; // Element type for blockscaling during accumulation
using ElementCompute = float; // Element type for epilogue computation
using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor
// Given TileShape = Shape<_128,_128,_128>:
// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor)
// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling
// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling
// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling
template <int ScaleGranularityM_, int ScaleGranularityN_>
struct GroupScaleConfig {
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
static constexpr int ScaleGranularityM = ScaleGranularityM_;
static constexpr int ScaleGranularityN = ScaleGranularityN_;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
"FP8 scaling granularity must evenly divide tile shape along M.");
static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile,
"FP8 scaling granularity must evenly divide tile shape along N.");
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
};
using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>;
using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>;
using GroupScale2D1DConfig = GroupScaleConfig<size<0>(TileShape_{}), 1>;
using GroupScale2D2DConfig = GroupScaleConfig<size<0>(TileShape_{}), size<1>(TileShape_{})>;
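// Worked example (illustrative only, not used by the kernels below): with the 128x128x128
// tile above, a single 256x256x512 group has ceil(K/128) = 4 scale "columns" per operand, so
// 1Dx1D scaling stores one factor per row of A / column of B per k-block (256x4 scale tensors),
// while 2Dx2D scaling stores one factor per whole CTA tile (2x4 scale tensors).
static_assert(GroupScale1D1DConfig::ScaleMsPerTile == 128, "1D scaling along M: 128 factors per 128-row tile");
static_assert(GroupScale2D2DConfig::ScaleMsPerTile == 1, "2D scaling along M: a single factor per 128-row tile");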
template <typename ScheduleConfig>
struct GroupScaleGemm {
using ArchTag = typename ScheduleConfig::ArchTag;
using OperatorClass = typename ScheduleConfig::OperatorClass;
using TileShape = typename ScheduleConfig::TileShape;
using ClusterShape = typename ScheduleConfig::ClusterShape;
using KernelSchedule = typename ScheduleConfig::KernelSchedule;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using FusionOperation = typename ScheduleConfig::FusionOperation;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC *, AlignmentC,
ElementD, LayoutD *, AlignmentD,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopWithGroupWiseScaling,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
using GroupScale1D1DGemm = GroupScaleGemm<GroupScale1D1DConfig>;
using GroupScale1D2DGemm = GroupScaleGemm<GroupScale1D2DConfig>;
using GroupScale2D1DGemm = GroupScaleGemm<GroupScale2D1DConfig>;
using GroupScale2D2DGemm = GroupScaleGemm<GroupScale2D2DConfig>;
// Extract information from Gemm kernel.
using EpilogueOutputOp = typename GroupScale1D1DGemm::Gemm::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
using StrideA = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideA;
using StrideB = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideB;
using StrideC = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideC;
using StrideD = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideD;
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
"ElementAccumulator and ElementBlockScale should be same datatype");
/// Initialization
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> offset_blockscale_A;
std::vector<int64_t> offset_blockscale_B;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
uint64_t seed;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementD> block_D;
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
cutlass::DeviceAllocation<const ElementA *> ptr_A;
cutlass::DeviceAllocation<const ElementB *> ptr_B;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<ElementD *> ptr_D;
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023,
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
double _scope_max, _scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
_scope_max = 2;
_scope_min = 0;
} else if (bits_input <= 8) {
_scope_max = 2;
_scope_min = -2;
} else if (bits_input == 16) {
_scope_max = 5;
_scope_min = -5;
} else {
_scope_max = 8;
_scope_min = -8;
}
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
_scope_max = scope_max;
}
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
_scope_min = scope_min;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
return true;
}
/// Allocates device-side data
template <typename OptionType>
void allocate(const OptionType &options) {
using TileShape = typename OptionType::GroupScaleConfig::TileShape;
const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile;
const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile;
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
int64_t total_elements_blockscale_A = 0;
int64_t total_elements_blockscale_B = 0;
offset_A.clear();
offset_B.clear();
offset_C.clear();
offset_D.clear();
offset_blockscale_A.clear();
offset_blockscale_B.clear();
stride_A_host.clear();
stride_B_host.clear();
stride_C_host.clear();
stride_D_host.clear();
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{})));
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in the scale tensor of B to prevent illegal memory access.
auto blockscale_k = cute::get<2>(blockscale_shape);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
offset_blockscale_A.push_back(total_elements_blockscale_A);
offset_blockscale_B.push_back(total_elements_blockscale_B);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
int64_t elements_blockscale_A = groupscale_m * blockscale_k;
int64_t elements_blockscale_B = groupscale_n * blockscale_k;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
total_elements_blockscale_A += elements_blockscale_A;
total_elements_blockscale_B += elements_blockscale_B;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
blockscale_block_A.reset(total_elements_blockscale_A);
blockscale_block_B.reset(total_elements_blockscale_B);
}
/// Initialize operands to be used in the GEMM and reference GEMM
template <typename OptionType>
void initialize(const OptionType &options) {
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementD *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
alpha_host.clear();
beta_host.clear();
for (int i = 0; i < options.groups; i++) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
ptr_blockscale_A.reset(options.groups);
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
ptr_blockscale_B.reset(options.groups);
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2022);
initialize_block(block_B, seed + 2023);
initialize_block(block_C, seed + 2024);
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
template<typename GemmArguments, typename OptionType>
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
{
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
int device_id = 0;
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename GroupScale1D1DGemm::Gemm::GemmKernel>(device_id);
GemmArguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
ptr_blockscale_A.get(),
ptr_blockscale_B.get()
},
{
{}, // epilogue.thread
ptr_C.get(), stride_C.get(),
ptr_D.get(), stride_D.get()
},
kernel_hw_info
};
auto &fusion_args = arguments.epilogue.thread;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per each group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
arguments.scheduler.raster_order = options.raster;
// The tile scheduler supports swizzle sizes up to 8, in powers of two (i.e., 1, 2, 4, and 8)
arguments.scheduler.max_swizzle_size = options.swizzle;
return arguments;
}
template <typename OptionType>
bool verify(const OptionType &options) {
//
// Compute reference output
//
std::vector<ElementA> block_A_host(block_A.size());
std::vector<ElementB> block_B_host(block_B.size());
std::vector<ElementC> block_C_host(block_C.size());
std::vector<ElementD> block_D_host_kernel(block_D.size());
std::vector<ElementD> block_D_host_ref(block_D.size());
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
block_A.copy_to_host(block_A_host.data());
block_B.copy_to_host(block_B_host.data());
block_C.copy_to_host(block_C_host.data());
block_D.copy_to_host(block_D_host_kernel.data());
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
bool passed = true;
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
// Group scaling tensor shapes are derived from ScaleGranularityM/N, the CTA tile (TileShape), and the GEMM problem shape
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
auto gemm_problem_shape = cute::make_shape(m, n, k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
auto groupscale_m = blockscale_m * OptionType::GroupScaleConfig::ScaleMsPerTile;
auto groupscale_n = blockscale_n * OptionType::GroupScaleConfig::ScaleNsPerTile;
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
cute::make_layout(
cute::make_shape(m, k, 1),
stride_A_host.at(group_idx)
)
);
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
cute::make_layout(
cute::make_shape(n, k, 1),
stride_B_host.at(group_idx)
)
);
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
cute::make_layout(
cute::make_shape(m, n, 1),
stride_C_host.at(group_idx)
)
);
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
cute::make_layout(
cute::make_shape(m, n, 1),
stride_D_host.at(group_idx)
)
);
auto blockscale_A = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
cute::make_layout(
cute::make_shape(groupscale_m, blockscale_k, 1),
cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k)
)
);
auto blockscale_B = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
cute::make_layout(
cute::make_shape(groupscale_n, blockscale_k, 1),
cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k)
)
);
using unused_t = decltype(D);
cutlass::reference::host::GettMainloopParams<
ElementAccumulator,
decltype(A),
decltype(B),
decltype(blockscale_A),
decltype(blockscale_B),
TileShape_
> mainloop_params{
A, B, // Operand Tensors
blockscale_A, blockscale_B // Groupwise scaling Tensors
};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
ElementScalar,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
unused_t, // Aux
unused_t, // valpha
unused_t, // vbeta
ActivationFunctor
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = alpha_host.at(group_idx);
epilogue_params.beta = beta_host.at(group_idx);
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Check if output from CUTLASS kernel and reference kernel are equal or not
auto this_group_passed = std::equal(
// std::execution::par_unseq,
block_D_host_ref.data() + offset_D.at(group_idx),
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
block_D_host_kernel.data() + offset_D.at(group_idx)
);
passed &= this_group_passed;
#if 0
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
#endif
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm, typename OptionType>
int run(OptionType &options, bool host_problem_shapes_available = true)
{
using TileShape = typename OptionType::GroupScaleConfig::TileShape;
const int ScaleGranularityM = OptionType::GroupScaleConfig::ScaleGranularityM;
const int ScaleGranularityN = OptionType::GroupScaleConfig::ScaleGranularityN;
const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile;
const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile;
allocate(options);
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::string raster = "Heuristic";
if (options.raster == RasterOrderOptions::AlongN) {
raster = "Along N";
}
else if (options.raster == RasterOrderOptions::AlongM) {
raster = "Along M";
}
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
fflush(stdout);
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.3 Toolkit (or newer) to run this example
// and must be run on a GPU with compute capability 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
std::cerr << "This example requires CUDA 12.3 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
//
// Parse options
//
Options<RasterOrderOptions, ProblemShape, GroupScale1D1DConfig> options_1d1d;
Options<RasterOrderOptions, ProblemShape, GroupScale1D2DConfig> options_1d2d;
Options<RasterOrderOptions, ProblemShape, GroupScale2D1DConfig> options_2d1d;
Options<RasterOrderOptions, ProblemShape, GroupScale2D2DConfig> options_2d2d;
options_1d1d.parse(argc, args);
options_1d2d.parse(argc, args);
options_2d1d.parse(argc, args);
options_2d2d.parse(argc, args);
if (options_1d1d.help) {
options_1d1d.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
auto run_tests = [&] (bool host_problem_shapes_available = true) {
std::cout << "Grouped GEMM kernel with 1D1D group scale" << std::endl;
run<GroupScale1D1DGemm::Gemm>(options_1d1d, host_problem_shapes_available);
std::cout << "Grouped GEMM kernel with 1D2D group scale" << std::endl;
run<GroupScale1D2DGemm::Gemm>(options_1d2d, host_problem_shapes_available);
std::cout << "Grouped GEMM kernel with 2D1D group scale" << std::endl;
run<GroupScale2D1DGemm::Gemm>(options_2d1d, host_problem_shapes_available);
std::cout << "Grouped GEMM kernel with 2D2D group scale" << std::endl;
run<GroupScale2D2DGemm::Gemm>(options_2d2d, host_problem_shapes_available);
std::cout << std::endl;
};
std::cout << "Running tests with host problem shapes:" << std::endl;
run_tests(true);
std::cout << "Running tests without host problem shapes:" << std::endl;
run_tests(false);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,61 @@
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=0) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=512 --groups=50 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
cutlass_example_add_executable(
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
)

View File

@ -0,0 +1,211 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Command line options parsing
template<typename _RasterOrderOptions, typename _ProblemShape, typename _GroupScaleConfig>
struct Options {
using RasterOrderOptions = _RasterOrderOptions;
using ProblemShape = _ProblemShape;
using GroupScaleConfig = _GroupScaleConfig;
bool help = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, groups = 10;
std::string benchmark_path;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
int const tma_alignment_bits = 128;
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
int const k_alignment = 128;
int const m_alignment = 128;
int const n_alignment = 128;
RasterOrderOptions raster;
int swizzle;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
char raster_char;
cmd.get_cmd_line_argument("raster", raster_char);
if (raster_char == 'N' || raster_char == 'n') {
raster = RasterOrderOptions::AlongN;
}
else if (raster_char == 'M' || raster_char == 'm') {
raster = RasterOrderOptions::AlongM;
}
else if (raster_char == 'H' || raster_char == 'h') {
raster = RasterOrderOptions::Heuristic;
}
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
// Decide how to initialize the problems
if (!benchmark_path.empty()) {
if (!benchmark_problems()) {
problem_sizes_host.clear();
return;
}
}
else {
randomize_problems(cmd);
}
}
void randomize_problems(cutlass::CommandLine &cmd) {
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes_host.reserve(groups);
for (int i = groups; i > 0; i--) {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
}
if (n < 1) {
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
}
if (k < 1) {
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
}
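// With alignment = 16 (128 TMA bits / 8-bit FP8) and {m,n,k}_alignment = 128, the random
// extents above are multiples of 128: M and N fall in [128, 1024] and K in [128, 512].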
problem_sizes_host.push_back({m, n, k});
}
}
/// Load a benchmark
bool benchmark_problems() {
std::ifstream file(benchmark_path);
if (!file.good()) {
return false;
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
cutlass::gemm::GemmCoord extent;
std::vector<std::string> tokens;
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
}
groups = static_cast<int>(problem_sizes_host.size());
return true;
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n"
<< " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
<< " --benchmark=<str> Executes a benchmark problem size.\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Number of real-valued multiply-adds
uint64_t fmas = 0ull;
for (auto const [m, n, k] : problem_sizes_host) {
fmas += static_cast<uint64_t>(m) *
static_cast<uint64_t>(n) *
static_cast<uint64_t>(k);
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};

View File

@ -0,0 +1,520 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for GETT in host-side code.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/gemm.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/relatively_equal.h"
#include <iostream>
#include "cute/tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::reference::host {
template<class T, class = void>
struct ElementTraits {
using type = T;
};
template<class T>
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
using type = decltype(std::declval<T>().get());
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementAccumulator_,
class TensorA_, // (M, K, L)
class TensorB_, // (N, K, L)
class TensorScaleA_, // (m, k, L)
class TensorScaleB_, // (n, k, L)
class TileShape_
>
struct GettMainloopParams {
using ElementAccumulator = ElementAccumulator_;
using TensorA = TensorA_;
using TensorB = TensorB_;
using EngineA = typename TensorA::engine_type;
using LayoutA = typename TensorA::layout_type;
using EngineB = typename TensorB::engine_type;
using LayoutB = typename TensorB::layout_type;
using TensorScaleA = TensorScaleA_;
using TensorScaleB = TensorScaleB_;
using TileShape = TileShape_;
using EngineScaleA = typename TensorScaleA::engine_type;
using EngineScaleB = typename TensorScaleB::engine_type;
TensorA A{};
TensorB B{};
TensorScaleA ScaleA{};
TensorScaleB ScaleB{};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
class ElementScalar_,
class ElementScalingFactor_,
class ElementAccumulator_,
class ElementCompute_,
class TensorC_, // (M, N, L)
class TensorD_, // (M, N, L)
class VectorBias_ = TensorD_, // (M, 1)
class TensorAux_ = TensorD_, // (M, N, L)
class VectorAlpha_ = TensorD_, // (M, 1)
class VectorBeta_ = VectorAlpha_, // (M, 1)
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
bool PerColumnBias_ = false
>
struct GettEpilogueParams {
using ElementScalar = ElementScalar_;
using ElementScalingFactor = ElementScalingFactor_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using TensorC = TensorC_;
using TensorD = TensorD_;
using TensorAux = TensorAux_;
using VectorBias = VectorBias_;
using VectorAlpha = VectorAlpha_;
using VectorBeta = VectorBeta_;
using ActivationFunctor = ActivationFunctor_;
using BiasBinaryOp = BiasBinaryOp_;
using EngineC = typename TensorC::engine_type;
using LayoutC = typename TensorC::layout_type;
using EngineD = typename TensorD::engine_type;
using LayoutD = typename TensorD::layout_type;
static constexpr bool PerColumnBias = PerColumnBias_;
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
TensorC C{};
TensorD D{};
VectorBias Bias{};
TensorAux Aux{};
VectorAlpha Valpha{};
VectorBeta Vbeta{};
ElementCompute st = ElementCompute(1);
ElementAccumulator* abs_max_D = nullptr;
ElementAccumulator* abs_max_Aux = nullptr;
ElementScalingFactor scale_a = ElementScalingFactor(1);
ElementScalingFactor scale_b = ElementScalingFactor(1);
ElementScalingFactor scale_c = ElementScalingFactor(1);
ElementScalingFactor scale_d = ElementScalingFactor(1);
ElementScalingFactor scale_aux = ElementScalingFactor(1);
bool beta_per_channel_scaling = false;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling
template <
class MainloopParams,
class EpilogueParams
>
void Gett(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
#if defined(_OPENMP)
#pragma omp parallel for collapse(3)
#endif
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
gett_mainloop(mainloop_params, m, n, l, acc);
gett_epilogue(epilogue_params, m, n, l, acc);
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Mainloop
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_mainloop(
MainloopParams const& mainloop_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
using cute::raw_pointer_cast;
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
RingOp fma_op;
multiplies<ElementAccumulator> scale_op;
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});
// Temporary accumulators to separate blockwise accumulation
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
// Zero out accumulators
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
const int M = cute::size<0>(mainloop_params.A.layout());
const int N = cute::size<0>(mainloop_params.B.layout());
const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA.layout());
const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB.layout());
assert(ScaleGranularityM && M % ScaleGranularityM == 0 && "ScaleGranularityM must divide M");
assert(ScaleGranularityN && N % ScaleGranularityN == 0 && "ScaleGranularityN must divide N");
cute::Tensor blockscale_A = domain_offset(make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l));
cute::Tensor blockscale_B = domain_offset(make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l));
// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
// Load the blockwise scaling factors for A and B from the blockscale tensors
int64_t block_k = k / kBlockK;
cute::Tensor scale_a = blockscale_A(_, block_k);
cute::Tensor scale_b = blockscale_B(_, block_k);
// Load A
ElementAccumulator a_frag[kBlockM];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
} else {
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// Load B
ElementAccumulator b_frag[kBlockN];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
} else {
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
}
}
// do compute
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
}
}
// Apply groupwise scaling at the kBlockK boundary:
// (a) apply the group/block scaling factors to the partially accumulated results (acc_temp),
// (b) zero out the temporary accumulator (acc_temp),
// (c) update the permanent accumulator (acc).
if ((k+1) % kBlockK == 0) {
for (int m_b = 0; m_b < kBlockM; ++m_b) {
auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GETT - Epilogue
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_epilogue(
EpilogueParams const& epilogue_params,
int64_t m,
int64_t n,
int64_t l,
ElementAccumulator (&acc)[kBlockM][kBlockN])
{
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, N, B");
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "M, N, B");
using cute::raw_pointer_cast;
using ElementCompute = typename EpilogueParams::ElementCompute;
using ElementC = typename EpilogueParams::TensorC::value_type;
using ElementD = typename EpilogueParams::TensorD::value_type;
using ElementAux = typename EpilogueParams::TensorAux::value_type;
using ElementBias = typename EpilogueParams::VectorBias::value_type;
using ElementScalar = typename EpilogueParams::ElementScalar;
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
constexpr bool IsScalingAndAmaxOutputNeeded =
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
constexpr bool IsReLUAuxNeeded =
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
constexpr bool IsClamp =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
constexpr bool IsBackpropFusion =
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
// Input related converter
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
NumericConverter<ElementCompute, ElementC> source_converter;
NumericConverter<ElementCompute, ElementBias> bias_converter;
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
// Scale related converter
NumericConverter<ElementCompute, ElementScalar> scale_converter;
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
// Abs max converter
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
// Output related converter
NumericConverter<ElementD, ElementCompute> destination_converter;
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
NumericConverter<ElementBias, ElementCompute> dBias_converter;
// Epilogue operations
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
multiplies<ElementCompute> mul;
plus<ElementCompute> add;
// Activation operation
auto activation = [] (ElementCompute x, ElementCompute y = ElementCompute(0)) {
if constexpr (std::is_same_v<ActivationFunctor, void>) {
return x + y;
} else {
return ActivationFunctor()(x, y);
}
};
// Bias binary operation
BiasBinaryOp bias_op;
// Do conversion
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
// Init local var
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
converted_beta = mul(converted_beta, converted_scale_c);
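// Taken together (ignoring the optional per-row alpha/beta, aux, and abs-max paths, and with the bias
// combined via BiasBinaryOp, written here as +), each output element is roughly
//   D = scale_d * activation( (alpha * scale_a * scale_b) * acc + bias + (beta * scale_c) * C )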
ElementCompute inter_accum[kBlockM][kBlockN];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
ElementCompute local_dBias = ElementCompute(0);
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
// per-row alpha
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
}
ElementCompute output = mul(converted_alpha, converted_acc);
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
output = bias_op(output, converted_bias);
}
if (raw_pointer_cast(epilogue_params.C.data())) {
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
// per-row beta
if (epilogue_params.Vbeta.data()) {
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
}
output = epilogue_fma(converted_beta, converted_src, output);
}
if constexpr (IsBackpropFusion) {
ElementAux aux_input = ElementAux(0);
if (raw_pointer_cast(epilogue_params.Aux.data())) {
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
}
output = activation(output, aux_source_converter(aux_input));
local_dBias = add(local_dBias, output);
}
else {
if (raw_pointer_cast(epilogue_params.Aux.data())) {
auto aux_output = output;
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
}
if constexpr (IsReLUAuxNeeded) {
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
} else {
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
}
}
if constexpr (IsClamp) { // Treat Clamp as ReLU
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
}
else {
output = activation(output);
}
}
if constexpr (IsScalingAndAmaxOutputNeeded) {
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
local_abs_max_output = amax_op(local_abs_max_output, output);
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
}
inter_accum[m_b][n_b] = ElementCompute(output);
}
} // n_b
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
local_dBias = add(local_dBias, converted_dBias);
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
}
}
} // m_b
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
}
}
}
#if defined(_OPENMP)
#pragma omp critical(Abs_Max_Data_Update)
#endif
{
if constexpr (IsScalingAndAmaxOutputNeeded) {
if (epilogue_params.abs_max_D) {
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
}
}
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
if (epilogue_params.abs_max_Aux) {
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM - General Matrix-Matrix contraction without conjugation options
template <
class MainloopParams,
class EpilogueParams
>
void Gemm3x(
MainloopParams const& mainloop_params,
EpilogueParams const& epilogue_params)
{
using namespace cute;
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) "
"with Batchmode are supported");
// Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett).
Gett(mainloop_params, epilogue_params);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // cutlass::reference::host
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,585 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape_MNK = Shape<_128,_128,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;
using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
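// With the trivial config above there is exactly one scale factor per 128x128x128 MMA tile of A and of B,
// so layout_SFA holds ceil(M/128) x ceil(K/128) scales and layout_SFB holds ceil(N/128) x ceil(K/128) scales per batch.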
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutC, AlignmentD,
cutlass::epilogue::TmaWarpSpecialized1Sm
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
// The data-tensor strides are packed and contain no zero strides
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
// The scale-factor layouts are tiled to the problem size; their strides contain zeros that broadcast each scale over its block
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "81_blackwell_gemm_blockwise\n\n"
<< " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "81_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
scope_min = -8;
scope_max = 8;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
tensor_SFA.resize(blockscale_a_coord);
tensor_SFB.resize(blockscale_b_coord);
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A,
tensor_B.device_data(), stride_B,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.0 Toolkit (or newer) to run this example,
// and the GPU must have compute capability 100a.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,589 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// MMA type
using ElementAccumulator = float; // Element Accumulator will also be our scale factor type
using ElementCompute = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
using MmaTileShape_MNK = Shape<_128,_128,_128>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_1,_1,_1>;
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 128;
constexpr int ScaleGranularityK = 128;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
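// With this configuration, A carries one scale per row of A for every 128-wide block of K,
// while B carries one scale per 128x128 (N x K) block.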
// Note: when there are multiple scale factors per tile in M (here 128 scales per tile), the kernel
// enforces a 16B alignment on the scales in M where possible (i.e., at least 16B of scales in M).
// The smallest M that can then be executed is 16. To support smaller M, swap A and B and transpose
// A, B, C, and the scales, since B^T A^T = C^T.
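// A minimal sketch of the swapped problem for small M (hypothetical; not instantiated by this example):
//   using ScaleConfigT = cutlass::detail::Sm100BlockwiseScaleConfig<
//       ScaleGranularityN, ScaleGranularityM, ScaleGranularityK>;  // M/N granularities exchanged
//   // Build the kernel with B^T as the "A" operand and A^T as the "B" operand and write D^T;
//   // C and both scale tensors are transposed accordingly.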
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutC, AlignmentD,
cutlass::epilogue::TmaWarpSpecialized1Sm
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
ElementAccumulator,
MmaTileShape_MNK, ClusterShape_MNK,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
// The data-tensor strides are packed and contain no zero strides
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
// The scale-factor layouts are tiled to the problem size; their strides contain zeros that broadcast each scale over its block
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA> tensor_A;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFA;
cutlass::HostTensor<ElementB , LayoutB> tensor_B;
cutlass::HostTensor<ElementAccumulator, cutlass::layout::PackedVectorLayout> tensor_SFB;
cutlass::HostTensor<ElementC , LayoutC> tensor_C;
cutlass::HostTensor<ElementD , LayoutD> tensor_D;
cutlass::HostTensor<ElementD , LayoutD> tensor_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
bool skip_verification = false;
float alpha = 1.f, beta = 0.f;
int iterations = 1000;
int m = 1024, n = 512, k = 1024, l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
if (cmd.check_cmd_line_flag("skip-verification")) {
skip_verification = true;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "81_blackwell_gemm_groupwise\n\n"
<< " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --skip-verification Skip verification.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "81_blackwell_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Helper to initialize a block of device data (scale_tensors)
template <typename Element, typename Layout>
bool initialize_scale_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
scope_min = -8;
scope_max = 8;
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
}
else {
throw std::runtime_error("Not implemented.");
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
using namespace cute;
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
tensor_SFA.resize(blockscale_a_coord);
tensor_SFB.resize(blockscale_b_coord);
initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022);
initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023);
initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024);
initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025);
initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_SFA.sync_device();
tensor_SFB.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A,
tensor_B.device_data(), stride_B,
tensor_SFA.device_data(), layout_SFA,
tensor_SFB.device_data(), layout_SFB},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA);
auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB);
using unused_t = decltype(D);
cutlass::reference::host::GettBlockScalingMainloopParams<
ElementAccumulator,
decltype(A),
decltype(SFA),
decltype(B),
decltype(SFB)
> mainloop_params{A, SFA, B, SFB};
cutlass::reference::host::GettEpilogueParams<
ElementAccumulator,
ElementAccumulator,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D)
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
Result result;
if (!options.skip_verification) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.0 Toolkit (or newer) to run this example,
// and the GPU must have compute capability 100a.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major != 10 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 100a." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Run
//
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
run<Gemm>(options);
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,57 @@
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
set(TEST_RANDOM --iterations=0) # Random problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Default problem sizes with non-trivial epilogue scalars
set(TEST_SMALL --m=256 --n=128 --k=128 --iterations=0) # Small problem sizes
cutlass_example_add_executable(
81_blackwell_gemm_blockwise
81_blackwell_gemm_blockwise.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_EPILOGUE
TEST_SMALL
)
cutlass_example_add_executable(
81_blackwell_gemm_groupwise
81_blackwell_gemm_groupwise.cu
TEST_COMMAND_OPTIONS
TEST_RANDOM
TEST_EPILOGUE
TEST_SMALL
)
endif()

View File

@ -146,6 +146,7 @@ foreach(EXAMPLE
64_ada_fp8_gemm_grouped
65_distributed_gemm
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
69_hopper_mixed_dtype_grouped_gemm
70_blackwell_gemm
71_blackwell_gemm_with_collective_builder
@ -156,6 +157,7 @@ foreach(EXAMPLE
76_blackwell_conv
77_blackwell_fmha
78_blackwell_emulated_bf16x9_gemm
81_blackwell_gemm_blockwise
)
add_subdirectory(${EXAMPLE})

View File

@ -0,0 +1,189 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA
*/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cute/int_tuple.hpp"
#include "cute/atom/mma_traits_sm100.hpp"
namespace cutlass::detail{
/////////////////////////////////////////////////////////////////////////////////////////////////
using namespace cute;
template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
struct Sm100BlockwiseScaleConfig {
using ShapeSFA = Shape<Shape<Int<SFVecSizeM>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
using ShapeSFB = Shape<Shape<Int<SFVecSizeN>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
using StrideSFA = conditional_t<majorSFA == UMMA::Major::MN,
Stride<Stride<_0,_1>,Stride<_0,int32_t>, int32_t>,
Stride<Stride<_0,int32_t>,Stride<_0,_1>, int32_t>>;
using StrideSFB = conditional_t<majorSFB == UMMA::Major::MN,
Stride<Stride<_0,_1>,Stride<_0,int32_t>, int32_t>,
Stride<Stride<_0,int32_t>,Stride<_0,_1>, int32_t>>;
using LayoutSFA = Layout<ShapeSFA, StrideSFA>;
using LayoutSFB = Layout<ShapeSFB, StrideSFB>;
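// Shapes are ((SFVecSize, num_scale_blocks), (SFVecSizeK, num_k_blocks), batch); the zero strides on the
// leading vector modes broadcast one scale factor across every element of the block it covers.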
CUTE_HOST_DEVICE
static constexpr auto
deduce_layoutSFA() {
return LayoutSFA{};
}
template<typename CtaShape_MNK>
CUTE_HOST_DEVICE
static constexpr auto
smem_atom_layoutSFA(CtaShape_MNK cta_shape_mnk) {
static_assert(cute::is_static_v<CtaShape_MNK>, "Expect static CTA shape");
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
auto [M, N, K] = cta_shape_mnk;
if constexpr (majorSFA == UMMA::Major::MN) {
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int<cute::ceil_div(size<0>(CtaShape_MNK{}), SFVecSizeM)>{}));
}
else {
return make_stride(make_stride(_0{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{}));
}
}();
auto [M, N, K] = cta_shape_mnk;
return make_layout(
make_shape(make_shape(Int<SFVecSizeM>{}, Int<cute::ceil_div(size<0>(CtaShape_MNK{}), SFVecSizeM)>{}),
make_shape(Int<SFVecSizeK>{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{})),
strides
);
}
CUTE_HOST_DEVICE
static constexpr auto
deduce_layoutSFB() {
return LayoutSFB{};
}
template<typename CtaShape_MNK>
CUTE_HOST_DEVICE
static constexpr auto
smem_atom_layoutSFB(CtaShape_MNK cta_shape_mnk) {
static_assert(cute::is_static_v<CtaShape_MNK>, "Expect static CTA shape");
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
if constexpr (majorSFB == UMMA::Major::MN) {
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int<cute::ceil_div(size<1>(CtaShape_MNK{}), SFVecSizeN)>{}));
}
else {
return make_stride(make_stride(_0{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{}));
}
}();
auto [M, N, K] = cta_shape_mnk;
return make_layout(
make_shape(make_shape(Int<SFVecSizeN>{}, Int<cute::ceil_div(size<1>(CtaShape_MNK{}), SFVecSizeN)>{}),
make_shape(Int<SFVecSizeK>{}, Int<cute::ceil_div(size<2>(CtaShape_MNK{}), SFVecSizeK)>{})),
strides
);
}
// The following function is provided for the user to fill the dynamic problem size into layout_SFA.
template <class ProblemShape>
CUTE_HOST_DEVICE
static constexpr auto
tile_atom_to_shape_SFA(ProblemShape problem_shape) {
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
auto [M, N, K, L] = problem_shape_MNKL;
if constexpr (majorSFA == UMMA::Major::MN) {
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, SFVecSizeM)));
}
else {
return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{}));
}
}();
auto [M, N, K, L] = problem_shape_MNKL;
auto mk_layout = make_layout(
make_shape(make_shape(Int<SFVecSizeM>{}, cute::ceil_div(M, SFVecSizeM)),
make_shape(Int<SFVecSizeK>{}, cute::ceil_div(K, SFVecSizeK))),
strides
);
return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout))));
}
// The following function is provided for the user to fill the dynamic problem size into layout_SFB.
template <class ProblemShape>
CUTE_HOST_DEVICE
static constexpr auto
tile_atom_to_shape_SFB(ProblemShape problem_shape) {
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE {
auto [M, N, K, L] = problem_shape_MNKL;
if constexpr (majorSFB == UMMA::Major::MN) {
return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, SFVecSizeN)));
}
else {
return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{}));
}
}();
auto [M, N, K, L] = problem_shape_MNKL;
auto nk_layout = make_layout(
make_shape(make_shape(Int<SFVecSizeN>{}, cute::ceil_div(N, SFVecSizeN)),
make_shape(Int<SFVecSizeK>{}, cute::ceil_div(K, SFVecSizeK))),
strides
);
return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout))));
}
};
template<class MmaTileShape_MNK>
constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) {
return Sm100BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{};
}
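// Illustrative usage sketch (not part of the original change; the chosen granularities and the
// host-side variable names are assumptions). For a groupwise 1x128x128 scale granularity, a
// caller could build the runtime scale-factor layouts for a dynamic (M, N, K, L) problem as:
//
//   using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128>;
//   auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
//   auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
//
// The static scale-vector sizes live in the layout shapes, while the dynamic block counts and
// the batch stride are filled in from the problem shape.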
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::detail

View File

@ -0,0 +1,304 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<
int CapacityBytes,
class ElementA,
class ElementB,
class ElementScalar,
class ScaleShapeMNK,
class TileShapeMNK,
class MainloopPipelineStorage,
class TransformLoadPipelineStorage,
class TransformPipelineStorage,
int stages
>
constexpr int
sm100_compute_stage_count_or_override_blockwise(StageCount<stages> stage_count) {
return stages;
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<
int CapacityBytes,
class ElementA,
class ElementB,
class ElementScalar,
class ScaleShapeMNK,
class TileShapeMNK,
class MainloopPipelineStorage,
class TransformLoadPipelineStorage,
class TransformPipelineStorage,
int stages
>
constexpr int
sm100_compute_stage_count_or_override_blockwise(cute::Int<stages> stage_count) {
return stages;
}
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<
int CapacityBytes,
class ElementA,
class ElementB,
class ElementScalar,
class ScaleShapeMNK,
class TileShapeMNK,
class MainloopPipelineStorage,
class TransformLoadPipelineStorage,
class TransformPipelineStorage,
int carveout_bytes>
constexpr int
sm100_compute_stage_count_or_override_blockwise(StageCountAutoCarveout<carveout_bytes> stage_count) {
// For F8/F6/F4 sub-bytes, ElementA/B will be passed in as uint8_t
// For Planar Complex, ElementA/B will be passed in as cutlass::complex<ElementARaw>
// Each stage includes (CollectiveMma::SharedStorage):
// 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage)
// 2. one of each of the pipelines
constexpr auto pipeline_bytes = sizeof(MainloopPipelineStorage) +
sizeof(TransformLoadPipelineStorage) + sizeof(TransformPipelineStorage);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
constexpr auto scale_bits = cute::sizeof_bits_v<ElementScalar>;
constexpr int stage_bytes =
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(scale_bits * size<0>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
cutlass::bits_to_bytes(scale_bits * size<1>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
static_cast<int>(pipeline_bytes);
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
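// Worked sketch of the computation above (assumed numbers, for illustration only): with a
// 128x128x128 CTA tile, 1-byte fp8 A/B, fp32 scales, and a (128, 1, 1) scale tile
// (i.e. 1x128x128 granularity), one stage needs roughly
//   128*128 (A) + 128*128 (B) + 4*128 (SFA) + 4*1 (SFB) + pipeline storage  ~= 33 KB,
// so the ~227 KB of SM100 shared memory, minus the scheduler and epilogue carveouts,
// typically yields on the order of five or six stages.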
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ElementA,
class GmemLayoutATagPair,
int AlignmentA,
class ElementB,
class GmemLayoutBTagPair,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType
>
struct CollectiveBuilder<
arch::Sm100,
arch::OpClassTensorOp,
ElementA,
GmemLayoutATagPair,
AlignmentA,
ElementB,
GmemLayoutBTagPair,
AlignmentB,
ElementAccumulator,
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1)
StageCountType,
KernelScheduleType,
cute::enable_if_t<
not cute::is_tuple_v<ElementA> && not cute::is_tuple_v<ElementB> &&
not cute::is_complex_v<ElementA> && not cute::is_complex_v<ElementB> &&
cute::is_tuple_v<GmemLayoutATagPair> && cute::is_tuple_v<GmemLayoutBTagPair> &&
// Dense Gemm
cute::is_base_of_v<KernelScheduleSm100Blockwise, KernelScheduleType> &&
// Alignment check
detail::sm1xx_gemm_is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, KernelScheduleType>()>>
{
static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
static_assert(detail::check_input_datatypes<ElementA, ElementB>(), "Incorrect input types");
using GmemLayoutATag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutATagPair{}))>;
using GmemLayoutSFATag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutATagPair{}))>;
using GmemLayoutBTag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutBTagPair{}))>;
using GmemLayoutSFBTag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutBTagPair{}))>;
static_assert(cute::depth(GmemLayoutSFATag{}) == 2 and cute::depth(GmemLayoutSFBTag{}) == 2,
"Expect SFA and SFB layout to be depth of two with shape ((SFVecMN, restMN),(SFVecK, restK), L)");
static_assert(size<1,0>(GmemLayoutSFATag{}) == size<1, 0>(GmemLayoutSFBTag{}),
"SFA and SFB must have equivalent SF vector sizes along K");
static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
// Data type used by MMA instruction
using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element<ElementA>());
using ElementBMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element<ElementB>());
static constexpr bool is_2sm = cute::is_base_of_v<KernelSchedule2Sm, KernelScheduleType> ||
(not cute::is_base_of_v<KernelSchedule1Sm, KernelScheduleType> &&
not cute::is_base_of_v<KernelSchedule2Sm, KernelScheduleType> &&
cute::is_static_v<ClusterShape_MNK> &&
cute::get<0>(ClusterShape_MNK{}) % 2 == 0 );
static_assert(detail::sm100_gemm_check_for_f8f6f4_mix8bit_requirement<ElementAMma, ElementBMma,
TileShape_MNK, ClusterShape_MNK,
UmmaMajorA, UmmaMajorB, KernelScheduleType, is_2sm>(),
"TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" );
using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma<
ElementAMma, ElementBMma, ElementAccumulator,
decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK,
UmmaMajorA, UmmaMajorB, KernelScheduleType>());
using ElementAMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
using ElementBMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
using AtomThrID = typename TiledMma::AtomThrID;
using AtomThrShapeMNK = cute::Shape<decltype(cute::shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
using CtaTileShape_MNK = decltype(cute::shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}),
cute::size<2>(TileShape_MNK{}))));
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}),
cute::size<2>(TileShape_MNK{}))));
using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));
using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
static_assert(BlockTileA_K{} == BlockTileB_K{}, "Block tile Ks should be equal");
using SmemShape_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{}))));
using SmemShape_N = decltype(shape_div(shape<1>(TileShape_MNK{}), shape_div(shape<1>(TileShape_MNK{}), size<1>(TileShape_MNK{}) / size(AtomThrID{}))));
using SmemShape_K = decltype(cute::get<2>(TileShape_MNK{}));
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
ClusterShape_MNK{}, AtomThrID{}));
using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B(
ClusterShape_MNK{}, AtomThrID{}));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
UmmaMajorA, ElementAMma_SmemAllocType, SmemShape_M, SmemShape_K>());
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
UmmaMajorB, ElementBMma_SmemAllocType, SmemShape_N, SmemShape_K>());
static constexpr uint32_t TotalTmemRows = 128;
static constexpr uint32_t Sm100TmemCapacityColumns = 512;
static constexpr uint32_t TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns;
static constexpr uint32_t AccumulatorPipelineStageCount = (is_2sm || (!is_2sm && size(shape<0,0>(MmaShapeA_MK{})) > 64)) ?
TotalTmem / (cute::size<0>(CtaTileShape_MNK{}) * cute::size<1>(CtaTileShape_MNK{}))
: (Sm100TmemCapacityColumns / cute::size<1>(CtaTileShape_MNK{})) * 2; // 1SM MMA_M = 64 case
static_assert(AccumulatorPipelineStageCount > 0, "Accumulator pipeline stage count must be positive. This error probably means that TileShape_MNK and/or TiledMma::ThrLayoutVMNK are wrong.");
// Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding.
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
using InternalStrideA = cute::remove_pointer_t<StrideA>;
// Grouped GEMM (where Stride type is Stride*) does not use CLC based scheduler.
// SchedulerPipelineStageCount could be set to zero for Grouped GEMM, but we shouldn't define CLC Pipeline's barrier arrays of size zero.
static constexpr uint32_t SchedulerPipelineStageCount = cute::is_same_v<InternalStrideA, StrideA> ? (AccumulatorPipelineStageCount + 1) : 1;
static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout<
ClusterShape_MNK,
AccumulatorPipelineStageCount,
SchedulerPipelineStageCount,
detail::CLCResponseSize,
false
>::KernelSmemCarveout;
// Reduce SMEM capacity available for buffers considering barrier allocations.
static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>;
using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage;
using TransformLoadPipelineStorage = typename cutlass::PipelineAsync<1>::SharedStorage;
using TransformPipelineStorage = typename cutlass::PipelineUmmaAsync<1>::SharedStorage;
static constexpr int ScaleGranularityM = size<0,0>(GmemLayoutSFATag{});
static constexpr int ScaleGranularityN = size<0,0>(GmemLayoutSFBTag{});
static constexpr int ScaleGranularityK = size<1,0>(GmemLayoutSFBTag{});
static_assert(size<0>(CtaTileShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape");
static_assert(size<1>(CtaTileShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape");
static_assert(size<2>(CtaTileShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape");
using BlockTileScale_M = Int<size<0>(TileShape_MNK{}) / ScaleGranularityM>;
using BlockTileScale_N = Int<size<1>(TileShape_MNK{}) / ScaleGranularityN>;
using BlockTileScale_K = Int<size<2>(TileShape_MNK{}) / ScaleGranularityK>;
using ScaleTileShape = cute::Shape<BlockTileScale_M, BlockTileScale_N, BlockTileScale_K>;
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockwise<
Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType,
ElementAccumulator, ScaleTileShape, SmemTileShape, MainloopPipelineStorage,
TransformLoadPipelineStorage, TransformPipelineStorage>(StageCountType{});
static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, and scales.");
using DispatchPolicy = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling<
PipelineStages,
SchedulerPipelineStageCount,
AccumulatorPipelineStageCount,
ClusterShape_MNK>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
cute::tuple<cutlass::gemm::TagToStrideA_t<GmemLayoutATag>, cutlass::gemm::TagToStrideA_t<GmemLayoutSFATag>>,
ElementB,
cute::tuple<cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, cutlass::gemm::TagToStrideB_t<GmemLayoutSFBTag>>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
void,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
void,
cute::identity
>;
};
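// Hedged usage sketch (assumptions: fp8 inputs, a 128x128x128 MMA tile, 1x128x128 scale
// granularity, and a CollectiveEpilogue built beforehand; none of these are prescribed by this
// builder). The blockwise builder is selected by passing (data layout tag, scale-factor layout)
// tuples together with a KernelScheduleSm100Blockwise schedule:
//
//   using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128>;
//   using LayoutSFA   = decltype(ScaleConfig::deduce_layoutSFA());
//   using LayoutSFB   = decltype(ScaleConfig::deduce_layoutSFB());
//
//   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
//       cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
//       cutlass::float_e4m3_t, cute::tuple<cutlass::layout::RowMajor, LayoutSFA>, 16,
//       cutlass::float_e4m3_t, cute::tuple<cutlass::layout::ColumnMajor, LayoutSFB>, 16,
//       float,
//       cute::Shape<cute::_128, cute::_128, cute::_128>, cute::Shape<cute::_1, cute::_1, cute::_1>,
//       cutlass::gemm::collective::StageCountAutoCarveout<
//           static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
//       cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>::CollectiveOp;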
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1046,8 +1046,7 @@ template <
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
int ScaleGranularityM_,
int ScaleGranularityN_
class KernelScheduleType
>
struct CollectiveBuilder<
arch::Sm90,
@ -1062,11 +1061,16 @@ struct CollectiveBuilder<
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>,
KernelScheduleType,
cute::enable_if_t<
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
cute::is_same_v<decltype(KernelScheduleType::ScaleGranularityM), decltype(KernelScheduleType::ScaleGranularityN)> and
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()
>
> {
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>;
static constexpr auto ScaleGranularityM_ = KernelScheduleType::ScaleGranularityM;
static constexpr auto ScaleGranularityN_ = KernelScheduleType::ScaleGranularityN;
static constexpr auto ScalePromotionInterval_ = KernelScheduleType::ScalePromotionInterval;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
@ -1076,12 +1080,12 @@ struct CollectiveBuilder<
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>);
static constexpr bool IsArrayOfPointersGemm = (
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelScheduleType> ||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelScheduleType>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
static_assert(IsFP8Input, "Warp-specialized GEMM with FP8 block-scaled accumulation is only supported for FP8 input types right now.");
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@ -1091,10 +1095,9 @@ struct CollectiveBuilder<
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_, ScaleGranularityN_>>;
static constexpr bool IsCooperative = cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, KernelScheduleType> ||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelScheduleType>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
@ -1121,7 +1124,9 @@ struct CollectiveBuilder<
static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile, ScaleNsPerTile>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_>;
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>,
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;

View File

@ -43,6 +43,7 @@
#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl"
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -50,6 +50,7 @@
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp"
#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
#if !defined(__CUDACC_RTC__)
#include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp"
@ -59,5 +60,6 @@
#include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp"
#endif // !defined(__CUDACC_RTC__)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -223,6 +223,30 @@ public:
mma_count_ = 0;
}
}
/// scale (multiply_add) the results from the MMA accumulators into the main accumulator without checking the counter.
CUTLASS_DEVICE
void scale(ElementAccumulator const &scale) {
scale_core(scale);
}
template <
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE
void scale(const cute::Tensor<EngineScale, LayoutScale> &scale) {
scale_core(scale);
}
template <
class EngineScaleA,
class LayoutScaleA,
class EngineScaleB,
class LayoutScaleB>
CUTLASS_DEVICE
void scale(const cute::Tensor<EngineScaleA, LayoutScaleA> &scaleA, const cute::Tensor<EngineScaleB, LayoutScaleB> &scaleB) {
scale_core(scaleA, scaleB);
}
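/// Note: these unconditional overloads complement scale_if_needed() / scale_residue_if_needed().
/// A mainloop that knows at compile time that promotion happens every iteration (the finest
/// ScalePromotionInterval) can call scale() directly and skip the internal MMA-counter check;
/// the blockwise-scaling mainloop in this change dispatches between the two variants through a
/// small scale_if_needed() helper.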
/// scale (multiply_add) the residue results from the MMA accumulators into the main accumulator if needed.
CUTLASS_DEVICE

File diff suppressed because it is too large

View File

@ -204,6 +204,8 @@ public:
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static constexpr int NumProducerThreadEvents = 1;
using SmemLayoutAtomScale = Layout<Shape<decltype(cute::shape<0>(SwappedSmemLayoutAtomA{})), cute::Int<1>>>;
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{})));
@ -1354,6 +1356,18 @@ public:
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in tensormaps_fence_acquire.");
}
}
template <class InputTensors, class ProblemShape_MNKL>
CUTLASS_DEVICE
InputTensors
tensors_perform_update(
InputTensors const& input_tensors,
[[maybe_unused]] Params const& mainloop_params,
[[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl,
[[maybe_unused]] int32_t next_batch) {
return input_tensors;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -116,6 +116,9 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params;
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
static constexpr int NumProducerThreadEvents = 1;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -749,6 +752,16 @@ struct CollectiveMma<
cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
}
template <class InputTensors, class ProblemShape_MNKL>
CUTLASS_DEVICE
InputTensors
tensors_perform_update(
InputTensors const& input_tensors,
[[maybe_unused]] Params const& mainloop_params,
[[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl,
[[maybe_unused]] int32_t next_batch) {
return input_tensors;
}
};

View File

@ -759,6 +759,18 @@ struct CollectiveMma<
cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
}
template <class InputTensors, class ProblemShape_MNKL>
CUTLASS_DEVICE
InputTensors
tensors_perform_update(
InputTensors const& input_tensors,
[[maybe_unused]] Params const& mainloop_params,
[[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl,
[[maybe_unused]] int32_t next_batch) {
return input_tensors;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -59,6 +59,7 @@ template <
class KernelSchedule,
int ScaleGranularityM_,
int ScaleGranularityN_,
int ScalePromotionInterval_,
class TileShape_,
class ElementA_,
class StrideA_,
@ -74,7 +75,7 @@ template <
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>,
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>,
TileShape_,
ElementA_,
StrideA_,
@ -93,7 +94,7 @@ struct CollectiveMma<
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_>;
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_, ScaleGranularityN_, ScalePromotionInterval_>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
@ -122,6 +123,8 @@ struct CollectiveMma<
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_;
static constexpr int ScalePromotionInterval = ScalePromotionInterval_;
static_assert(ScalePromotionInterval % 4 == 0, "ScalePromotionInterval must be a multiple of 4.");
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
@ -281,7 +284,9 @@ struct CollectiveMma<
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{});
implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval);
implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval);
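// Example (illustrative): with TileShape K = 128 and the FP8 GMMA atom's MMA_K = 32,
// pipe_k = 128 / 32 = 4, so the smallest legal promotion interval is 4 and
// args.mma_promotion_interval must equal the schedule's static ScalePromotionInterval.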
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
@ -481,6 +486,38 @@ struct CollectiveMma<
}
}
template<
class EngineAccum,
class LayoutAccum,
class ScaleFactor
>
CUTLASS_DEVICE
void scale_if_needed(GmmaFP8Accumulation<EngineAccum, LayoutAccum>& accumulation, ScaleFactor scaleFactor) {
if constexpr (ScalePromotionInterval != 4) {
accumulation.scale_if_needed(scaleFactor);
}
else {
// avoid unnecessary checks when granularity is the finest
accumulation.scale(scaleFactor);
}
}
template<
class EngineAccum,
class LayoutAccum,
class ScaleFactor1,
class ScaleFactor2
>
CUTLASS_DEVICE
void scale_if_needed(GmmaFP8Accumulation<EngineAccum, LayoutAccum>& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) {
if constexpr (ScalePromotionInterval != 4) {
accumulation.scale_if_needed(scaleFactor1, scaleFactor2);
}
else {
// avoid unnecessary checks when granularity is the finest
accumulation.scale(scaleFactor1, scaleFactor2);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <
@ -575,7 +612,7 @@ struct CollectiveMma<
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA));
warpgroup_fence_operand(accumulation());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
@ -584,7 +621,13 @@ struct CollectiveMma<
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if (accumulation.prepare_if_needed()) {
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
}
else {
// Always zero out the accumulator for finest granularity
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
@ -624,16 +667,16 @@ struct CollectiveMma<
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_if_needed(scale_ab);
scale_if_needed(accumulation, scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC);
}
++smem_pipe_read;
@ -677,7 +720,13 @@ struct CollectiveMma<
}
}
if (accumulation.prepare_if_needed()) {
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
}
else {
// Always zero out the accumulator for finest granularity
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
@ -699,16 +748,16 @@ struct CollectiveMma<
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_if_needed(scale_ab);
scale_if_needed(accumulation, scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC);
}
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
@ -718,18 +767,21 @@ struct CollectiveMma<
++smem_pipe_release;
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_residue_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
if constexpr (ScalePromotionInterval != 4) {
// residues only exist when granularity is not the finest
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0];
accumulation.scale_residue_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleBViewAsC);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC);
}
}
warpgroup_fence_operand(accumulation());

View File

@ -120,10 +120,38 @@ template<
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0
int ScaleGranularityM_ = 0,
int ScaleGranularityN_ = 0,
// `ScalePromotionInterval` specifies the interval to promote the accumulator for scaling
// It is required to be a multiple of 4 and is specified in terms of the number of MMA instructions
// in the reduction dimension, i.e. for FP8 kernels it is
// ScalePromotionInterval * MMA_K = ScalePromotionInterval * 32 = 128 elements in K by default
int ScalePromotionInterval_ = 4
>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative {
constexpr static int ScaleGranularityM = ScaleGranularityM_;
constexpr static int ScaleGranularityN = ScaleGranularityN_;
constexpr static int ScalePromotionInterval = ScalePromotionInterval_;
};
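// Hedged usage sketch (assumed tile and granularities, not prescribed by this change): a Hopper
// FP8 GEMM with 1x128 scaling of A and 128x128 scaling of B on a 128x128x128 tile could pick
//   using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<
//       /*ScaleGranularityM_=*/1, /*ScaleGranularityN_=*/128, /*ScalePromotionInterval_=*/4>;
// With FP8 MMA_K = 32, an interval of 4 promotes the partial accumulator every 4 * 32 = 128
// elements along K, which matches the 128-element scale blocks along K.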
template<
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM_,
int ScaleGranularityN_,
// `ScalePromotionInterval` specifies the interval to promote the accumulator for scaling
// It is required to be a multiple of 4 and is specified in terms of the number of MMA instructions
// in the reduction dimension, i.e. for FP8 kernels it is
// ScalePromotionInterval * MMA_K = ScalePromotionInterval * 32 = 128 elements in K by default
int ScalePromotionInterval_ = 4
>
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative {
constexpr static int ScaleGranularityM = ScaleGranularityM_;
constexpr static int ScaleGranularityN = ScaleGranularityN_;
constexpr static int ScalePromotionInterval = ScalePromotionInterval_;
};
// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -310,12 +338,17 @@ template<
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0
int ScaleGranularityN = 0,
// `ScalePromotionInterval` specifies the interval to promote the accumulator for scaling
// It is required to be a multiple of 4 and is specified in terms of the number of MMA instructions
// in the reduction dimension, i.e. for FP8 kernels it is
// ScalePromotionInterval * MMA_K = ScalePromotionInterval * 32 = 128 elements in K by default
int ScalePromotionInterval = 4
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>>,
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN, ScalePromotionInterval>>,
"KernelSchedule must be one of the warp specialized policies");
};
@ -327,6 +360,7 @@ template<
>
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
constexpr static int Stages = Stages_;
constexpr static int PipelineAsyncMmaStages = 1;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
@ -391,6 +425,26 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput {
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
// For FP8 kernels with Block Scaling
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative,
// `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value
// `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is
// `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N.
int ScaleGranularityM = 0,
int ScaleGranularityN = 0,
int ScalePromotionInterval = 4
>
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_same_v<KernelSchedule, KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>>,
"KernelSchedule must be one of the warp specialized policies");
};
template<
int SchedulerPipelineStageCount_,
@ -411,6 +465,14 @@ struct KernelTmaWarpSpecializedBlockScaledSm100 final {
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
};
template<
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_
>
struct KernelTmaWarpSpecializedMmaTransformSm100 final {
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
};
// InputTransform GEMM
@ -484,6 +546,13 @@ struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {};
struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {};
struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {};
///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 Blockwise GEMM Dispatch Policies
///////////////////////////////////////////////////////////////////////////////////////////////////////
struct KernelScheduleSm100Blockwise : KernelScheduleSm100 {};
struct KernelTmaWarpSpecializedBlockwise1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100Blockwise {};
///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 Planar Complex GEMM Dispatch Policies
///////////////////////////////////////////////////////////////////////////////////////////////////////
@ -530,6 +599,9 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelSch
struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { };
struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { };
struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { };
///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies
///////////////////////////////////////////////////////////////////////////////////////////////////////
// BlockScaled Dense GEMM + (Ptr Array or Group GEMM)
struct KernelSchedulePtrArrayBlockScaledGemmSm100 : KernelScheduleBlockScaledGemmSm100 {};
struct KernelSchedulePtrArrayMxNvf4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {};
@ -544,8 +616,6 @@ struct KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, K
struct KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { };
struct KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { };
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,
@ -561,7 +631,20 @@ struct MainloopSm100TmaUmmaWarpSpecialized {
constexpr static bool IsOverlappingAccum = false;
};
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_,
class ClusterShape_ = Shape<_1,_1,_1>
>
struct MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm100;
using Schedule = KernelTmaWarpSpecializedMmaTransformSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
constexpr static bool IsOverlappingAccum = false;
};
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
template<

View File

@ -65,6 +65,7 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp"

File diff suppressed because it is too large

View File

@ -128,10 +128,11 @@ public:
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaThreads = CUTE_STATIC_V(size(TiledMma{}));
static constexpr uint32_t NumMmaThreads = size(TiledMma{});
static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
@ -434,7 +435,8 @@ public:
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = size(TiledMma{});
mainloop_pipeline_params.num_consumers = NumMmaThreads;
mainloop_pipeline_params.num_producers = NumProducerThreads;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
@ -575,6 +577,7 @@ public:
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
if (did_batch_change) {
load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch);
collective_mainloop.tensormaps_fence_acquire(input_tensormaps);
}

View File

@ -131,6 +131,7 @@ public:
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
@ -443,6 +444,7 @@ public:
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.num_producers = NumProducerThreads;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
@ -607,6 +609,7 @@ public:
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
if (did_batch_change) {
load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch);
collective_mainloop.tensormaps_fence_acquire(input_tensormaps);
}