CUTLASS 3.6.0 (#1850)

* v3.6

* update changelog

* update readme

* fix typo

* fixing typos

* hopper gemm with weight prefetch

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Yujia Zhai
2024-10-09 12:33:27 -07:00
committed by GitHub
parent 0837a2a00a
commit cc3c29a81a
354 changed files with 105943 additions and 8203 deletions

View File

@ -42,7 +42,8 @@
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_size.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
@ -56,6 +57,7 @@
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"
@ -657,7 +659,9 @@ struct Testbed {
}
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n();
int64_t bytes = cutlass::bits_to_bytes(
(cutlass::sizeof_bits<ElementD>::value * 2 + cutlass::sizeof_bits<ElementSoftmax>::value) *
options.problem_size.m() * options.problem_size.n());
double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30);
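For sub-byte element types, sizeof() rounds up to whole bytes, which is why the byte count above is now derived from sizeof_bits and converted with bits_to_bytes. A minimal sketch of the same arithmetic, assuming only that bits_to_bytes rounds a bit count up to whole bytes (the helper below is illustrative, not the CUTLASS function itself):

#include <cstdint>
#include "cutlass/numeric_types.h"  // cutlass::sizeof_bits, cutlass::int4b_t

// Illustrative helper: bytes needed to store `count` elements of type Element.
template <class Element>
constexpr int64_t storage_bytes(int64_t count) {
  int64_t bits = int64_t(cutlass::sizeof_bits<Element>::value) * count;
  return (bits + 7) / 8;  // round the bit count up to whole bytes
}

static_assert(storage_bytes<cutlass::int4b_t>(1024) == 512,
              "two 4-bit elements per byte; sizeof() would report one byte each");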

View File

@ -303,14 +303,14 @@ bool initialize_block(
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
scope_max = Element(2);
scope_min = Element(0);
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
scope_max = Element(2);
scope_min = Element(-2);
} else {
scope_max = 8;
scope_min = -8;
scope_max = Element(8);
scope_min = Element(-8);
}
cutlass::reference::device::BlockFillRandomUniform(

View File

@ -111,7 +111,7 @@ public:
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16> {
struct PipelineStorage : cute::aligned_struct<16, _2> {
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;

View File

@ -50,7 +50,7 @@ struct PermuteTraits {};
using X = Underscore;
// Reshape a rank-2 shape into a multidimensional shape.
// Input:
// shape = (A, B, ...)
// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...)
// Output:
@ -76,12 +76,12 @@ reshape(Shape const& shape, TargetShape const& target_shape)
// - sub-modes corresponding to the implied multidimensional shape of the source tensor
// - strides accounting for the permutation operation being performed
template<class Permute, bool Transpose, class Shape, class Stride>
constexpr auto
make_permute_layout(Layout<Shape,Stride> const& layout) {
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
if constexpr (Transpose) {
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
return select<1,0,2>(make_permute_layout<Permute, false>(select<1,0,2>(layout)));
}
else {
@ -129,23 +129,24 @@ inverse(Permutation const & perm) {
template<class T>
using inverse_t = decltype(inverse(T{}));
// Given a rank-2 layout of a tensor that is assumed to have been permuted,
// compute the original rank-2 layout of the tensor prior to the permutation.
// This is needed to form the correct input to the standalone permutation kernel.
template<class Permute, bool Transpose, class Shape, class Stride>
constexpr auto
make_original_layout(Layout<Shape,Stride> const& layout) {
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
if constexpr (Transpose) {
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
return select<1,0,2>(make_original_layout<Permute, false>(select<1,0,2>(layout)));
}
else {
using ShapeProfile = typename PermuteTraits<Permute>::ShapeProfile;
auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{}));
using IndexOrder = typename PermuteTraits<Permute>::IndexOrder;
auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get<i>(re_shape); });
using OrigOrder = conditional_t<cutlass::gemm::detail::is_major<0,Stride>(), seq<0,1,2>, seq<1,0,2>>;
auto orig_shape = select(flatten(reshape(layout.shape(), ShapeProfile{})), IndexOrder{});
// print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n");
// print("Original shape: "); print(orig_shape); print("\n");
return make_ordered_layout(product_each(orig_shape), OrigOrder{});
@ -202,7 +203,7 @@ struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D>>
};
template<int D>
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D>>
{
static constexpr bool kBatched = true;
using ShapeProfile = Shape<Shape<X,Int<D>>, Shape<X>, Shape<X>>;
@ -222,7 +223,7 @@ struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D>>
};
template<int D>
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D>>
{
static constexpr bool kBatched = true;
using ShapeProfile = Shape<Shape<X>, Shape<X,Int<D>>, Shape<X>>;

View File

@ -0,0 +1,701 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
This example shows how to perform an INT4 x FP8 GEMM and scale the INT4 weights during dequantization. It uses a look-up table to avoid the multiplications
between INT4 and FP8. To trigger this method, use cutlass::Array<ElementScale, 8> as the scale type in the collective's arguments.
However, this algorithm requires changes to the encoding of the INT4 weights and scale factors, and these changes must happen before launching the GEMM. See the helper functions
`unify_quant_encoding`, `initialize_packed_scale`, and the header `packed_scale.hpp` for details.
In a nutshell, the positive values of INT4 weights need to be encoded in the same way as negative values except for the sign bit. For each scale factor,
8 negative results (-8 x scale, -7 x scale, ... -1 x scale) are packed together, forming a cutlass::Array<ElementScale, 8> value.
The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue,
as illustrated in this example.
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k].
The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size
equal to the gemm problem K.
Limitations:
1) Only supports INT4 x { FP8, INT8, UINT8 }. The scales must have the same type as the MMA type. Scale-with-zero-point mode is not supported.
2) The INT4 weights and scale factors have additional encoding requirements.
3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
4) The scales must have the same layout and groupsize.
5) The groupsize must be greater than or equal to the tile shape K.
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
We plan to address this limitation in the future. For now, the example works around it by explicitly swapping and transposing the operands.
Optimizing suggestions:
1) Use a small tile size, since the register pressure for this GEMM (and RS GEMMs in general) is high.
Examples:
Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0)
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0
Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire
matrix (group size is the same as the gemm k dimension).
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1
*/
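As a quick check of the scale_k relationship stated in the comment above (the helper name below is hypothetical; the arithmetic is the same ceil_div the example performs):

// Illustrative only: scale_k = ceil_div(problem_k, group_size).
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
static_assert(ceil_div(8192, 128) == 64, "64 scale groups along K");
static_assert(ceil_div(8192, 8192) == 1, "group size == K gives per-column scales");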
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "helper.h"
#include "unfused_weight_dequantize.hpp"
#include "packed_scale.hpp"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
// A matrix configuration
using ElementA = MmaType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = QuantType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using ElementScale = MmaType;
using ElementZero = ElementScale; // only used by the verification path
using LayoutScale = cutlass::layout::RowMajor;
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementAccumulator,
// Transpose layout of D here since we use explicit swap + transpose
// the void type for C tells the builder to allocate 0 smem for the C matrix.
// We can enable this if beta == 0 by changing ElementC to void below.
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
>::CollectiveOp;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The scale information must be paired with the operand that will be scaled. In this example, B is scaled, so we make a tuple of B's element type and the scale type.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8> >, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = typename GemmKernelScaleOnly::StrideC;
using StrideD = typename GemmKernelScaleOnly::StrideD;
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideC_ref stride_C_ref;
StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementB> block_B_modified;
cutlass::DeviceAllocation<ElementA> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8>> block_scale_packed;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 10;
int m = 5120, n = 4096, k = 4096;
int g = 128;
int l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("g", g);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "55_hopper_warp_specialized_gemm\n\n"
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
<< " --g=<int> The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t error = cudaSuccess;
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <typename Element>
bool initialize_quant_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
// In the mainloop, a PRMT can select one byte from only 8 bytes, so the sign bit is handled by an extra PRMT.
// Here the encodings of positive values and negative values are unified (except for the sign bit).
// For instance, 1 becomes 0b0111, which matches the encoding of -1 (0b1111) except for the sign bit.
bool unify_quant_encoding(
cutlass::DeviceAllocation<cutlass::int4b_t> const& block_in,
cutlass::DeviceAllocation<cutlass::int4b_t>& block_out) {
using StorageType = cutlass::int4b_t::Storage;
if (block_in.size() != block_out.size()) {
std::cerr << "block_in and block_out must have same size.\n";
return false;
}
constexpr int pack = sizeof_bits_v<StorageType> / 4;
std::vector<StorageType> data(block_in.size() / pack);
cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack);
for (auto&& d : data) {
StorageType out = 0;
StorageType mask = 0x0f;
for (int i = 0; i < pack; ++i) {
cutlass::int4b_t curr;
curr.storage = (d >> (i * 4)) & 0x0f;
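// The switch below maps each positive magnitude v (1..7) to the low bits of -v,
// so +v and -v share the same 3-bit index into the packed table of negative
// products; the sign bit alone then decides whether a sign fix-up is applied.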
switch (curr) {
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
default: break;
}
out |= (curr.storage << (4 * i)) & mask;
mask <<= 4;
}
d = out;
}
cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2);
return true;
}
template <class Element>
bool initialize_scale(
cutlass::DeviceAllocation<Element>& block,
Options const& options) {
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
float const max_dequant_val = 4.f;
float const min_dequant_val = 0.5f;
float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
bool initialize_packed_scale(
cutlass::DeviceAllocation<ElementScale> const& block_in,
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8> > & block_out) {
std::vector<ElementScale> data_in(block_in.size());
std::vector<cutlass::Array<ElementScale, 8> > data_out(block_in.size());
try {
block_in.copy_to_host(data_in.data());
} catch (cutlass::cuda_exception const& e)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
for (size_t i = 0; i < block_in.size(); ++i)
{
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
// std::cout << data_in[i] << ":" << std::hex << static_cast<uint16_t>(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast<uint16_t>((-data_in[i]).storage) << std::endl;
}
try {
block_out.copy_from_host(data_out.data());
} catch (cutlass::cuda_exception const& e)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
return true;
}
template <class Element>
bool initialize_zero(
cutlass::DeviceAllocation<Element>& block,
Options const& options) {
std::vector<Element> stage(block.size(), Element(0.0f));
block.copy_from_host(stage.data());
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options const& options) {
auto shape_b = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
// Reverse stride here due to swap and transpose
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
block_A.reset(a_coord.product());
block_B.reset(b_coord.product());
block_B_modified.reset(b_coord.product());
block_B_dq.reset(b_coord.product());
block_C.reset(c_coord.product());
block_D.reset(c_coord.product());
block_ref_D.reset(c_coord.product());
block_scale.reset(scale_k * options.l * options.n);
block_scale_packed.reset(scale_k * options.l * options.n);
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
unify_quant_encoding(block_B, block_B_modified);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_packed_scale(block_scale, block_scale_packed);
initialize_zero(block_zero, options);
auto layout_B = make_layout(shape_b, stride_B);
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Args>
Args args_from_options(Options const& options)
{
// Swap the A and B tensors, as well as problem shapes here.
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
bool verify(Options const& options) {
//
// Compute reference output
//
// In this example, we use the GPU default kernels as a reference (unfused scale).
// This avoids numerical differences due to different accumulation order.
// Again, due to numerical differences, we must use fast acc here when the mma type is
// FP8 as the fused implementation only supports fast acc at the moment.
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
ScheduleRef
>::CollectiveOp;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
};
// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
// compare_reference
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<typename Gemm::Arguments>(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with the CUDA 12.0 Toolkit (or newer) to run this example,
// and the device must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
if (options.g == options.k) {
std::cout << "Running in per-column scale mode." << std::endl;
} else {
std::cout << "Running in group scale mode." << std::endl;
}
run<GemmScaleOnly>(options);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -53,14 +53,18 @@
equal to the gemm problem K.
Limitations:
1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}.
2) The narrow type must always be in K-major format.
3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
4) The scales and the zeros must have the same layout and groupsize.
1) The narrow type must always be in K-major format.
2) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
3) The scales and the zeros must have the same layout and groupsize.
4) The groupsize must be greater than or equal to the tile shape K.
5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format.
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
We plan to address this limitation in the future. For now, the example works around it by explicitly swapping and transposing the operands.
Optimizing suggestions:
1) Use a small tile size, since the register pressure for this GEMM (and RS GEMMs in general) is high.
2) Try to avoid using the scale or zero-point modes, since the extra computation will become the bottleneck.
Examples:
@ -94,11 +98,8 @@
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "helper.h"
#include "unfused_weight_dequantize.hpp"
@ -117,8 +118,8 @@ enum GemmMode {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
using MmaType = cutlass::half_t;
using QuantType = cutlass::float_e4m3_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
// A matrix configuration
@ -154,8 +155,8 @@ using ElementAccumulator = float; // E
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_256,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
@ -268,14 +269,14 @@ using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;
cutlass::HostTensor<MmaType, LayoutA> tensor_A;
cutlass::HostTensor<QuantType, LayoutB> tensor_B;
cutlass::HostTensor<MmaType, LayoutB> tensor_B_dq;
cutlass::HostTensor<ElementScale, LayoutScale> tensor_scale;
cutlass::HostTensor<ElementZero, LayoutScale> tensor_zero;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_D;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_ref_D;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementA> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput> block_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
@ -290,7 +291,7 @@ struct Options {
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 1000;
int iterations = 10;
int mode = 2;
int m = 5120, n = 4096, k = 4096;
int g = 128;
@ -368,9 +369,9 @@ struct Result
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element, class Layout>
template <class Element>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
double scope_max, scope_min;
@ -393,34 +394,35 @@ bool initialize_tensor(
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <typename Element, typename Layout>
template <typename Element>
bool initialize_quant_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <class Element, class Layout>
template <class Element>
bool initialize_scale(
cutlass::TensorView<Element, Layout> view,
const Options &options) {
cutlass::DeviceAllocation<Element>& block,
Options const& options) {
if (options.mode == GemmMode::ConvertOnly) {
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(1.0f));
std::vector<Element> stage(block.size(), Element(1.0f));
block.copy_from_host(stage.data());
}
else {
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
@ -430,32 +432,33 @@ bool initialize_scale(
float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
}
return true;
}
template <class Element, class Layout>
template <class Element>
bool initialize_zero(
cutlass::TensorView<Element, Layout> view,
const Options &options) {
cutlass::DeviceAllocation<Element>& block,
Options const& options) {
if (options.mode == GemmMode::ScaleWithZeroPoint) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2.0f, -2.0f);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
} else {
// No zero-point, so just initialize with 0 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(0.0f));
std::vector<Element> stage(block.size(), Element(0.0f));
block.copy_from_host(stage.data());
}
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
void initialize(Options const& options) {
auto shape_b = cute::make_shape(options.n, options.k, options.l);
const int scale_k = (options.k + options.g - 1) / options.g;
int const scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
@ -469,27 +472,21 @@ void initialize(const Options &options) {
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_B_dq.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
block_A.reset(a_coord.product());
block_B.reset(b_coord.product());
block_B_dq.reset(b_coord.product());
block_C.reset(c_coord.product());
block_D.reset(c_coord.product());
block_ref_D.reset(c_coord.product());
tensor_scale.resize({scale_k * options.l, options.n});
tensor_zero.resize({scale_k * options.l, options.n});
block_scale.reset(scale_k * options.l * options.n);
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(tensor_A.host_view(), seed + 2022);
initialize_quant_tensor(tensor_B.host_view(), seed + 2021);
initialize_tensor(tensor_C.host_view(), seed + 2020);
initialize_scale(tensor_scale.host_view(), options);
initialize_zero(tensor_zero.host_view(), options);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_scale.sync_device();
tensor_zero.sync_device();
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
auto layout_B = make_layout(shape_b, stride_B);
@ -498,37 +495,36 @@ void initialize(const Options &options) {
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g);
tensor_B_dq.sync_host();
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Args>
Args args_from_options(const Options &options)
Args args_from_options(Options const& options)
{
// Swap the A and B tensors, as well as problem shapes here.
if (options.mode == GemmMode::ConvertOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
{block_B.get(), stride_B, block_A.get(), stride_A},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
} else {
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
@ -542,7 +538,7 @@ bool verify(const Options &options) {
//
// In this example, we use the GPU default kernels as a reference (unfused scale)
// This is to avoid numerical differences from different accumulation order.
// This avoids numerical differences due to different accumulation order.
// Again, due to numerical differences, we must use fast acc here when the mma type is
// FP8 as the fused implementation only supports fast acc at the moment.
@ -581,8 +577,8 @@ bool verify(const Options &options) {
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref}
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
};
// Run the gemm where the scaling is performed outside of the kernel.
@ -594,11 +590,9 @@ bool verify(const Options &options) {
CUTLASS_CHECK(gemm_ref.run());
// compare_reference
tensor_D.sync_host();
tensor_ref_D.sync_host();
const ElementD epsilon(1e-2f);
const ElementD non_zero_floor(1e-4f);
bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor);
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
}

View File

@ -55,5 +55,16 @@ cutlass_example_add_executable(
TEST_SCALE_ZERO_GROUPED
TEST_SCALE_RESIDUE
TEST_SCALE_ZERO_RESIDUE
TEST_ALPHA_BETA
# TEST_ALPHA_BETA
)
cutlass_example_add_executable(
55_hopper_int4_fp8_gemm
55_hopper_int4_fp8_gemm.cu
TEST_COMMAND_OPTIONS
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_GROUP
TEST_SCALE_RESIDUE
# TEST_ALPHA_BETA
)

View File

@ -11,6 +11,8 @@ This first version only supports mixed type GEMMs using TMA.
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.
The scale-only mode for `fp8 x int4` is significantly slower than the direct-conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, pass `cutlass::Array<ElementScale, 8>` as the scale type to the collective builder. However, it requires modifications to the encoding of the quantized weights and scale factors, and the scale-with-zero-point mode is not currently supported.
We are currently optimizing the following cases:
1. Memory bound cases for all types

View File

@ -0,0 +1,132 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include <cstdint>
#include "cutlass/float8.h"
namespace cutlass
{
template<typename T>
class packed_scale_t {
public:
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
cute::is_same_v<T, cutlass::uint8_t> ||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
cute::is_same_v<T, cutlass::float_e5m2_t>,
"only 8 bit arithmetic types are supported.");
CUTLASS_HOST_DEVICE
explicit packed_scale_t(T val) {
if constexpr (!cute::is_unsigned_v<T>) {
// Only pack negative values. The positive values are generated in flight in the mainloop.
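// storage[0] packs {-8*val, -7*val, -6*val, -5*val} and storage[1] packs
// {-4*val, -3*val, -2*val, -val}, one 8-bit product per byte of each word.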
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
}
else {
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
}
}
CUTLASS_HOST_DEVICE
packed_scale_t() = default;
CUTLASS_HOST_DEVICE
explicit operator float() const {
return float(get());
}
CUTLASS_HOST_DEVICE
bool operator==(packed_scale_t const& rhs) const {
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
}
CUTLASS_HOST_DEVICE
bool operator!=(packed_scale_t const& rhs) const {
return !(*this == rhs);
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() + rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() - rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() * rhs.get());
}
CUTLASS_HOST_DEVICE
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
return packed_scale_t(lhs.get() / rhs.get());
}
private:
using Storage = uint32_t;
using Stage = uint8_t;
Storage storage[2] {};
CUTLASS_HOST_DEVICE
static Storage pack4(T c1, T c2, T c3, T c4) {
Storage result = 0;
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
return result;
}
CUTLASS_HOST_DEVICE
T get() const {
auto stage = static_cast<Stage>(storage[0] >> 8);
#if defined(__CUDA_ARCH__)
return reinterpret_cast<T const&>(stage);
#else
T tmp;
std::memcpy(&tmp, &stage, sizeof(Stage));
return tmp;
#endif
}
CUTLASS_HOST_DEVICE
T get(int idx) const {
Stage stage;
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
#if defined(__CUDA_ARCH__)
return reinterpret_cast<T const&>(stage);
#else
T tmp;
std::memcpy(&tmp, &stage, sizeof(Stage));
return tmp;
#endif
}
};
}
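A minimal host-side usage sketch, mirroring how the example's initialize_packed_scale helper consumes this class (illustrative only, not part of the header):

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "packed_scale.hpp"

int main() {
  using Scale = cutlass::float_e4m3_t;
  // Pack the eight negative products of a single scale value...
  cutlass::packed_scale_t<Scale> packed(Scale(0.5f));
  // ...then view it as the Array<Scale, 8> type the collective builder expects,
  // as initialize_packed_scale does in 55_hopper_int4_fp8_gemm.cu.
  auto const& lut = reinterpret_cast<cutlass::Array<Scale, 8> const&>(packed);
  (void)lut;  // in the example, these Array values are copied to device memory
  static_assert(cutlass::Array<Scale, 8>::kElements == 8, "eight packed entries");
  return 0;
}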

View File

@ -32,7 +32,7 @@
/*! \file
\brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA
warp-specialized cooperative kernel.
The new feature showcased in this example is on-the-fly modification of TMA descriptors
to move between batches (represented by l).
@ -547,3 +547,4 @@ int main(int argc, char const **args) {
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -91,9 +91,9 @@
using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

View File

@ -26,6 +26,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if (NOT MSVC)
cutlass_example_add_executable(
59_ampere_gather_scatter_conv
ampere_gather_scatter_conv.cu
@ -34,3 +36,5 @@ cutlass_example_add_executable(
if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND)
target_link_libraries(59_ampere_gather_scatter_conv PRIVATE OpenMP::OpenMP_CXX)
endif()
endif()

View File

@ -0,0 +1,534 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper GEMM + Top-K + Softmax fusion
This example illustrates how to use the LinCombTopKSoftmaxCol EVT node to fuse
Top-K and Softmax into the GEMM epilogue, with certain assumptions made.
Those assumptions are as follows:
1. Fusion is over the N dimension.
2. Top-K is either 2 or 4 elements, and the value is static (meaning two kernels have to be
compiled to support both).
3. The GEMM tile shape along N is greater than or equal to problem size
along N.
The example runs the fused GEMM kernel, along with a standard unfused host reference, and
manually performs Top-K and softmax, and compares the error between tensors.
Note that some numerical error (smaller than 1e-5) is to be expected, but this is true
in most efficient reduction kernels, because floating point addition is not necessarily
associative.
*/
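For reference, the row-wise transform that the fused epilogue produces can be sketched on the host as follows. This is an illustration only and not part of the example's build; the helper name topk_softmax_row is hypothetical, and it matches the example's reference semantics only when the kept values are distinct, since the example sums exponentials over the K reduced values rather than over every element passing the threshold.
// Standalone host-side sketch: keep the Top-K largest values of a row,
// apply softmax over them, and zero out everything else.
#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>
inline std::vector<float> topk_softmax_row(std::vector<float> row, int k = 2) {
  // Find the k-th largest value; it acts as the keep/drop threshold.
  std::vector<float> sorted = row;
  std::nth_element(sorted.begin(), sorted.begin() + (k - 1), sorted.end(), std::greater<float>());
  float threshold = sorted[k - 1];
  // Standard max-subtracted softmax over the kept elements.
  float row_max = *std::max_element(row.begin(), row.end());
  float sum = 0.f;
  for (float v : row) {
    if (v >= threshold) sum += std::exp(v - row_max);
  }
  for (float &v : row) {
    v = (v >= threshold) ? std::exp(v - row_max) / sum : 0.f;
  }
  return row;
}
For example, topk_softmax_row({1.f, 3.f, 2.f, 0.5f}, 2) returns approximately {0, 0.731, 0.269, 0}.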
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
static constexpr int TopK = 2;
static constexpr bool EnableTopKSoftmax = TopK > 1;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
constexpr int AlignmentC = 1;
// D matrix configuration
using ElementD = cutlass::half_t; // Element type for D matrix operand
using LayoutD = cutlass::layout::RowMajor; // Layout type for output
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of output in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
// Top-K + Softmax fusion operation
using FusionOperation = std::conditional_t<EnableTopKSoftmax,
typename cutlass::epilogue::fusion::LinCombTopKSoftmaxCol<TopK, ElementD, ElementCompute>,
typename cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute>
>;
// The fusion op only allows for epilogue tiles matching the mainloop tile.
using EpilogueTileType = decltype(cute::take<0,2>(TileShape{}));
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
EpilogueSchedule,
FusionOperation
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Extract information from Gemm kernel.
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideD stride_D;
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
using LayoutScalar = cutlass::layout::PackedVectorLayout;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
int iterations = 1000;
int m = 16, n = 8, k = 64, l = 1;
double eps = 1e-5;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("eps", eps);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "61_hopper_gemm_with_topk_and_softmax\n\n"
<< " Hopper FP8 GEMM with Top-K and softmax fusion.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --eps=<float> Threshold of numerical verification. Default: 1e-5.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "61_hopper_gemm_with_topk_and_softmax" << " --m=16 --n=8 --k=1024 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
float alpha() const {
return 1.f / static_cast<float>(k);
}
};
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a host tensor view with uniform random data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, /* max = */ 1, /* min = */ -1, /* bits = */ 2);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
initialize_tensor(tensor_A.host_view(), seed + 2022);
initialize_tensor(tensor_B.host_view(), seed + 2023);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_D.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options) {
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
{
{options.alpha(), 0.f}, // alpha, beta
nullptr, stride_D,
tensor_D.device_data(), stride_D
}
};
return arguments;
}
bool verify(const Options &options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
using unused_t = decltype(D);
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
ElementScalar,
ElementAccumulator,
ElementCompute,
unused_t,
decltype(D),
unused_t, // bias
unused_t, // aux
unused_t, // valpha
unused_t // vbeta
> epilogue_params;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha();
epilogue_params.beta = 0.f;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
if constexpr (EnableTopKSoftmax) {
// top-K + softmax
for (int i = 0; i < options.m; ++i) {
// Find Top-K
cutlass::Array<ElementAccumulator, TopK> top_k;
top_k.fill(-cutlass::platform::numeric_limits<ElementCompute>::infinity());
for (int j = 0; j < options.n; ++j) {
auto val = static_cast<ElementAccumulator>(tensor_ref_D.host_view().ref().at({i, j}));
for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) {
if (val > top_k[top_k_idx]) {
// Shift down
for (int l = TopK - 1; l > top_k_idx; --l) {
top_k[l] = top_k[l - 1];
}
top_k[top_k_idx] = val;
break;
}
}
}
// This formulation of top-K + softmax only works when it is
// guaranteed that none of the top-K elements are repeated!
// If elements are repeated, the device kernel can also make mistakes, because
// A. Once the top-K values are reduced and the operation is applied,
// there is no way to tell repeated elements apart, so none are masked.
// B. The softmax sum of exps will be incorrect (because the repeated elements
// are not repeated in it).
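// For example, with a row of [3, 2, 2] and TopK = 2, the top-K reduction here
// yields {3, 2}; the threshold (2) then keeps all three elements, while the
// sum of exps counts the value 2 only once, so the outputs no longer sum to 1.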
ElementAccumulator max = top_k[0];
ElementAccumulator sum = ElementAccumulator(0.f);
for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) {
sum = sum + cutlass::fast_exp(top_k[top_k_idx] - max);
}
for (int j=0; j < options.n; ++j) {
auto val = tensor_ref_D.host_view().ref().at({i, j});
if (val < top_k[TopK - 1]) {
tensor_ref_D.host_view().ref().at({i, j}) = static_cast<ElementD>(0.f);
} else {
// Softmax
auto softmax_val = cutlass::fast_exp(val - max) / sum;
tensor_ref_D.host_view().ref().at({i, j}) = static_cast<ElementD>(softmax_val);
}
}
}
}
// compare_reference
tensor_D.sync_host();
double err = cutlass::reference::host::TensorRelativeErrorMetric(
tensor_D.host_view(),
tensor_ref_D.host_view());
bool passed = err < options.eps;
if (options.m <= 32 && options.n <= 32) {
std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n";
std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n";
}
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options) {
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,32 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
61_hopper_gemm_with_topk_and_softmax
61_hopper_gemm_with_topk_and_softmax.cu
)

View File

@ -0,0 +1,596 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper Sparse GEMM example.
This example demonstrates how to construct and run a structured sparse GEMM kernel
on NVIDIA Hopper architecture.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = float; // Element type for C and D matrix operands
using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size for sparse kernel
using TileShapeRef = Shape<_128,_128, _64>; // Threadblock-level tile size for reference (dense) kernel
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; // Kernel schedule policy
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue schedule policy
using ProblemShape = Shape<int,int,int,int>;
// Sparse kernel setup
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutTagC, AlignmentC,
ElementC, LayoutTagC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
ElementA, LayoutTagA, AlignmentA,
ElementB, LayoutTagB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference (dense) kernel setup
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShapeRef, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutTagC, AlignmentC,
ElementC, LayoutTagC, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutTagA, AlignmentA,
ElementB, LayoutTagB, AlignmentB,
ElementAccumulator,
TileShapeRef, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule
>::CollectiveOp;
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloopRef,
CollectiveEpilogue
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
// Layouts
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
// Layouts for reference (non-sparse) tensors
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
// Offline compressor kernel
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape,
ElementA,
LayoutTagA,
SparseConfig,
cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
//
// Data members
//
ProblemShape problem_shape;
StrideA stride_A;
StrideA stride_A_compressed;
StrideE stride_E;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
LayoutA layout_A;
LayoutE layout_E;
uint64_t seed;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A_compressed;
cutlass::DeviceAllocation<typename Gemm::CollectiveMainloop::ElementE> block_E;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D_ref;
#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
float alpha, beta;
int iterations;
int m, n, k, l;
Options():
help(false),
m(5120), n(4096), k(16384), l(1),
alpha(1.f), beta(0.f),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "62_hopper_sparse_gemm\n\n"
<< " Hopper Sparse GEMM example.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the L extent of the GEMM (batch size)\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "62_hopper_sparse_gemm" << " --m=4096 --n=5120 --k=8192 --l=1 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = Element(2);
scope_min = Element(0);
} else if (bits_input <= 8) {
scope_max = Element(2);
scope_min = Element(-2);
} else {
scope_max = Element(8);
scope_min = Element(-8);
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Make A structured sparse by replacing elements with 0 and compress it
bool sparsify_and_compress()
{
auto [M, N, K, L] = problem_shape;
CompressorUtility compressor_utility(problem_shape, stride_A);
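// Physical extents chosen by the sparse config: the compressed A tensor is
// M x KC (KC is roughly half of K for 2:4 structured sparsity), and the
// metadata tensor E is ME x KE.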
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();
block_A_compressed.reset(M * KC * L);
block_E.reset(ME * KE * L);
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L));
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L));
// Random sparsification is performed on host
std::vector<ElementA> block_A_host(block_A.size());
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast<int>(seed + 2024));
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments {
problem_shape,
{ block_A.get(),
stride_A,
block_A_compressed.get(),
block_E.get() },
{hw_info} };
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
bool initialize(Options const& options) {
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
auto [M, N, K, L] = problem_shape;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
// Allocate memory for tensors
block_A.reset(M * K * L);
block_B.reset(N * K * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_D_ref.reset(M * N * L);
// Fill input tensors with data
initialize_block(block_A, seed + 2021);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2023);
// Replace 0 in A with 1 to avoid metadata changes
std::vector<ElementA> block_A_host(block_A.size());
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
for (size_t i = 0; i < block_A.size(); ++i) if (block_A_host[i] == ElementA(0)) block_A_host[i] = ElementA(1.0);
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
if (!sparsify_and_compress()) {
return false;
};
// Build the compressed/metadata layouts
layout_A = SparseConfig::fill_layoutA(problem_shape);
layout_E = SparseConfig::fill_layoutE(problem_shape);
return true;
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments make_args(Options const& options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_shape,
{ block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E },
{ { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) },
block_C.get(), stride_C, block_D.get(), stride_D }
};
return arguments;
}
typename GemmRef::Arguments make_args_ref(Options const& options)
{
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_shape,
{ block_A.get(), stride_A, block_B.get(), stride_B },
{ { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) },
block_C.get(), stride_C, block_D_ref.get(), stride_D }
};
return arguments;
}
template<class Engine, class Layout>
void print_device_tensor(cute::Tensor<Engine, Layout> const& t)
{
// Assumes size = cosize, i.e. compact tensor
std::vector<typename Engine::value_type> data_host(t.size());
cutlass::device_memory::copy_to_host(data_host.data(), t.data(), t.size());
auto t_host = cute::make_tensor(data_host.data(), t.layout());
cute::print_tensor(t_host);
}
bool verify(Options const& options) {
CUDA_CHECK(cudaDeviceSynchronize());
bool passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size());
#if 0
if (!passed) {
auto [M, N, K, L] = problem_shape;
CompressorUtility compressor_utility(problem_shape, stride_A);
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();
cute::print("A (original): "); print_device_tensor(make_tensor(block_A.get(), make_shape(M, K, L), stride_A));
cute::print("A (compressed): "); print_device_tensor(make_tensor(block_A_compressed.get(), make_shape(M, KC, L), stride_A_compressed));
cute::print("E (physical): "); print_device_tensor(make_tensor(block_E.get(), make_shape(ME, KE, L), stride_E));
cute::print("E (logical): "); print_device_tensor(make_tensor(block_E.get(), upcast<CollectiveMainloop::ElementEMmaSparsity>(layout_E)));
cute::print("B: "); print_device_tensor(make_tensor(block_B.get(), make_shape(N, K, L), stride_B));
cute::print("C: "); print_device_tensor(make_tensor(block_C.get(), make_shape(M, N, L), stride_C));
cute::print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M, N, L), stride_D));
cute::print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M, N, L), stride_D));
}
#endif
return passed;
}
template<typename Gemm>
struct Runner
{
using Arguments = typename Gemm::Arguments;
Runner(Arguments args): arguments(args) {
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
workspace.reset(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
}
void run() {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
void benchmark(Options const& options) {
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
run();
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double gflops = options.gflops(avg_runtime_ms / 1000.0);
std::cout << " Avg runtime: " << avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << gflops << std::endl;
}
}
Gemm gemm;
Arguments arguments;
cutlass::device_memory::allocation<uint8_t> workspace;
};
/// Execute the example (verification and timing)
void run(Options &options) {
bool init = initialize(options);
if (!init) {
std::cout << "Initialization failure" << std::endl;
exit(EXIT_FAILURE);
}
Runner<Gemm> gemm(make_args(options));
Runner<GemmRef> gemm_ref(make_args_ref(options));
gemm.run();
gemm_ref.run();
bool passed = verify(options);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
if (!passed) {
exit(EXIT_FAILURE);
}
std::cout << "Sparse GEMM:" << std::endl;
gemm.benchmark(options);
std::cout << "Dense GEMM:" << std::endl;
gemm_ref.benchmark(options);
}
#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.2 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) {
std::cerr << "This example requires CUDA 12.2 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
run(options);
#endif
return EXIT_SUCCESS;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,36 @@
# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Sparse kernel in this example triggers an ICE in gcc 7.5
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
cutlass_example_add_executable(
62_hopper_sparse_gemm
62_hopper_sparse_gemm.cu
)
endif()

View File

@ -0,0 +1,500 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper FP8 GEMM + L2 Weight Prefetch
This example implements a non-persistent warp-specialized GEMM kernel for the Hopper
architecture with programmatic dependent launch (PDL) enabling prefetching weights into
L2 cache.
For more information about dependent launch refer to the CUDA programming guide:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization
In some cases, PDL can result in a window where a previous kernel is not actively utilizing
DRAM, and the next kernel sits idle until the previous finishes. During this window, the next
kernel can begin loading a non-dependent operand (i.e. weights in a linear projection are
typically static) and cache it in L2.
The kernel and collective mainloop assume operand `A` corresponds to weights and operand `B`
corresponds to activations (so we can have very small batch/token count).
After initialization, the prefetch warp starts loading K tiles of `A` into an unused portion
of shared memory, and loads up to half of all K tiles that the same CTA would eventually load.
The exact number of K tiles loaded is determined by `args.mainloop.prefetch_ratio` \in
[0.0, 1.0]. Smaller values result in less prefetching, and larger values result in more.
Negative values result in a "best-effort" prefetch, meaning the prefetcher will stop issuing
weight loads as soon as the activation DMA warp starts loading (i.e. as soon as it is signaled
that the previous kernel has flushed its memory).
The DMA warp responsible for loading `A` will also begin loading K tiles until it fills up
the available shared memory.
The DMA warp responsible for loading `B` will wait until activations are flushed to global
memory by the preceding kernel.
Another mainloop parameter, `args.mainloop.overlap_ratio` \in [0.0, 1.0] determines how early
the next kernel (the one doing the prefetch) is launched. Smaller values result in greater
overlap, and larger values result in smaller overlap. Negative values disable PDL completely,
meaning there will be no overlap. This will make prefetch ineffective.
These two runtime parameters should be tuned per problem size and GEMM config combination, and
if feasible, per-operation in an entire layer or model.
NOTE: you must build this target with the following flag to enable Grid Dependency Control
instructions (GDC) in CUTLASS:
- CUTLASS_ENABLE_GDC_FOR_SM90
To lock persistence mode, power (350W), clocks (1005MHz) for evaluation (assumes device 0 and H100)
$ sudo nvidia-smi -pm 1 -i 0
$ sudo nvidia-smi -i 0 -pl 350
$ sudo nvidia-smi -i 0 -lgc 1005
Example:
$ mkdir build && cd build
$ cmake .. -DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_GDC_FOR_SM90=1
$ cd examples/63_hopper_gemm_with_weight_prefetch
$ make
$ ./63_hopper_gemm_with_weight_prefetch --p=0.5 --o=0.5
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "collective/dispatch_policy_extra.hpp"
#include "collective/builder.hpp"
#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp"
#include "helper.h"
#include "gemm_with_weight_prefetch_commandline.hpp"
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C matrix configuration
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size
// Cluster_N > 1 is not supported yet.
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Extract information from Gemm kernel.
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
using LayoutScalar = cutlass::layout::PackedVectorLayout;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
double eff_bw;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
double eff_bw = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), eff_bw(eff_bw), status(status), error(error), passed(false)
{}
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a host tensor view with uniform random data
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
uint64_t seed) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
initialize_tensor(tensor_A.host_view(), seed + 2022);
initialize_tensor(tensor_B.host_view(), seed + 2023);
initialize_tensor(tensor_C.host_view(), seed + 2024);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
{
{}, // epilogue.thread
tensor_C.device_data(), stride_C,
tensor_D.device_data(), stride_D
}
};
auto &fusion_args = arguments.epilogue.thread;
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = scalar_alpha.device_data();
fusion_args.beta_ptr = scalar_beta.device_data();
arguments.mainloop.overlap_ratio = options.overlap_ratio;
arguments.mainloop.prefetch_ratio = options.prefetch_ratio;
return arguments;
}
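The header comment notes that prefetch_ratio and overlap_ratio should be tuned per problem size and GEMM configuration. A minimal sketch of such a sweep is given below; the helper sweep_prefetch_knobs and the particular ratio values are hypothetical, and the sketch simply reuses the Options, args_from_options, GpuTimer, and Gemm definitions from this file.
// Hypothetical tuning helper (not part of this example): time every
// combination of the two runtime knobs and report the average runtime.
template <typename GemmT>
void sweep_prefetch_knobs(Options options) {
  for (float overlap : {-1.0f, 0.05f, 0.5f, 0.8f}) {      // -1 disables PDL
    for (float prefetch : {-1.0f, 0.25f, 0.5f, 1.0f}) {   // -1 is best-effort
      options.overlap_ratio = overlap;
      options.prefetch_ratio = prefetch;
      auto arguments = args_from_options(options);
      GemmT gemm;
      size_t workspace_size = GemmT::get_workspace_size(arguments);
      cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
      CUTLASS_CHECK(gemm.can_implement(arguments));
      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
      // Warmup, then timed loop; PDL is only requested when overlap >= 0.
      CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ overlap >= 0.f));
      GpuTimer timer;
      timer.start();
      for (int iter = 0; iter < options.iterations; ++iter) {
        CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ overlap >= 0.f));
      }
      timer.stop();
      std::cout << "  overlap_ratio=" << overlap << " prefetch_ratio=" << prefetch
                << " avg runtime: " << timer.elapsed_millis() / float(options.iterations) << " ms" << std::endl;
    }
  }
}
A call such as sweep_prefetch_knobs<Gemm>(options) from main(), after the options are parsed and initialize(options) has run, would print one line per combination.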
bool verify(const Options &options) {
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
auto B = cute::make_tensor(tensor_B.host_data(),
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
auto C = cute::make_tensor(tensor_C.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
auto D = cute::make_tensor(tensor_ref_D.host_data(),
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
using unused_t = decltype(D);
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
cutlass::reference::host::GettEpilogueParams<
ElementScalar,
ElementScalar,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
unused_t, // bias
unused_t, // aux
unused_t, // valpha
unused_t // vbeta
> epilogue_params;
epilogue_params.C = C;
epilogue_params.D = D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// get reference result
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// compare_reference
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
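// Launch with PDL (programmatic dependent launch) only when overlap is enabled,
// i.e. when a non-negative overlap ratio was requested.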
CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0));
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0));
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0);
result.gflops = options.gflops(avg_runtime_s);
result.eff_bw = options.effective_bandwidth(avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD));
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
std::cout << " Effective bandwidth: " << result.eff_bw << " GB/s" << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif
return 0;
}

View File

@ -0,0 +1,36 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include_directories(
.
)
cutlass_example_add_executable(
63_hopper_gemm_with_weight_prefetch
63_hopper_gemm_with_weight_prefetch.cu
)

View File

@ -0,0 +1,82 @@
# GEMM with L2 weight prefetch
A non-persistent warp-specialized GEMM targeted at low-latency inference.
The kernel can optionally prefetch a portion of weights (operand `A`) into L2 cache while the
rest of the warps are waiting on the previous kernel to finish writing and flush its memory.
An example of this is normalization or reduction kernels that are immediately followed by a GEMM.
It exposes two runtime parameters:
1. `overlap_ratio`: how early `griddepcontrol.launch_dependent_grids` is issued.
Default is `0.5`, meaning it is issued once approximately half of the K tiles have been loaded by the DMA warps.
2. `prefetch_ratio`: what percentage of K tiles to prefetch.
Default is `-1.0`, meaning prefetching will stop as soon as other DMA warps are past
`griddepcontrol`.
It is highly recommended to auto-tune these parameters per GEMM against some end-to-end
runtime (an entire transformer layer or several, but probably not the entire model).
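Both parameters live on the mainloop arguments as plain floats, so they can be changed (or swept
by an auto-tuner) at runtime without recompiling. A minimal sketch, assuming a `Gemm::Arguments`
object (`arguments`), a `Gemm` instance (`gemm`), and a workspace allocation set up as in the
accompanying example:
```cxx
// Forward the runtime knobs to the prefetch-enabled mainloop.
// Negative values disable the corresponding feature (see above).
arguments.mainloop.overlap_ratio  = 0.5f;  // signal dependent grids after ~half the K tiles
arguments.mainloop.prefetch_ratio = 0.7f;  // prefetch roughly 70% of the K tiles
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Launch with PDL so the overlap with the preceding kernel takes effect.
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ true));
```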
TMA loads use non-default cache hints: `A` (weights) are loaded with `EvictFirst`, and `B` (activation)
is loaded with `EvictLast`.
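For reference, the mainloop collective in this example applies those hints directly on the TMA
copy atoms; the snippet below is lifted (and lightly reformatted) from the producer loop, where
`tma_barrier`, the multicast masks, and the partitioned tensors are set up as usual:
```cxx
// A (weights): EVICT_FIRST L2 cache hint.
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST),
     tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
// B (activations): EVICT_LAST L2 cache hint.
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST),
     tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
```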
## Getting started
To use this kernel in your own target, add this directory to your includes, and include the
following headers from this example:
```cxx
#include "collective/dispatch_policy_extra.hpp"
#include "collective/builder.hpp"
#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp"
```
And then use either one of the new kernel schedules:
```cxx
// Without separate warps for A and B
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetch;
// With separate warps for A and B
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
```
The kernel with separate warps for A and B (
`KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA`)
is expected to be more performant than the other, especially since it allows the kernel to load
weights into shmem ahead of the `griddepcontrol`.
As for other GEMM parameters: thread block clusters larger than 1 CTA are not yet supported,
the kernel layer implementation is warp specialized and relies on TMA, and other kernel
layers or collectives would require reimplementation.
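Putting it together, the mainloop collective can be composed with the standard CUTLASS 3.x
collective builder once one of the schedules above is selected. The sketch below is illustrative
only, assuming FP8 (`e4m3`) TN operands, a 128x128x128 tile, an automatic stage count, and a
1x1x1 cluster (larger clusters are not yet supported); the complete, working composition is in
`63_hopper_gemm_with_weight_prefetch.cu`:
```cxx
using ElementA = cutlass::float_e4m3_t;            // FP8 weights, row-major (T)
using ElementB = cutlass::float_e4m3_t;            // FP8 activations, column-major (N)
using LayoutA  = cutlass::layout::RowMajor;
using LayoutB  = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, 128 / cutlass::sizeof_bits<ElementA>::value,
    ElementB, LayoutB, 128 / cutlass::sizeof_bits<ElementB>::value,
    ElementAccumulator,
    cute::Shape<cute::_128, cute::_128, cute::_128>, // tile shape (M, N, K)
    cute::Shape<cute::_1, cute::_1, cute::_1>,       // cluster shape
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA
  >::CollectiveOp;
```
The epilogue collective, `GemmUniversal` kernel, and `GemmUniversalAdapter` are then composed
exactly as in other Hopper warp-specialized examples.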
## Example
Using the example is mostly straightforward.
Just build, and run with your choice of `MNK`:
```bash
./63_hopper_gemm_with_weight_prefetch --m=8192 --n=1 --k=8192
```
You can also disable the overlap or try different overlap and prefetch ratios and see the
difference:
```bash
echo "Without overlap and prefetch"
./63_hopper_gemm_with_weight_prefetch --o=-1.0 --p=-1.0
echo "Overlap ratio of 0.5, best effort prefetch"
./63_hopper_gemm_with_weight_prefetch --o=0.5 --p=-1.0
echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
./63_hopper_gemm_with_weight_prefetch --o=0.8 --p=0.7
```
However, note that the example still runs a single GEMM, and most of the performance improvement
is expected in end-to-end applications.
## Limitations
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
memory barrier before issuing every single TMA load, and in many cases this will slow down
prefetching to the point of being almost ineffective.

View File

@ -0,0 +1,215 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "dispatch_policy_extra.hpp"
#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
namespace cutlass::gemm::collective {
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
template <
class ElementA,
class GmemLayoutATag,
int AlignmentA,
class ElementB,
class GmemLayoutBTag,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType
>
struct CollectiveBuilder<
arch::Sm90,
arch::OpClassTensorOp,
ElementA,
GmemLayoutATag,
AlignmentA,
ElementB,
GmemLayoutBTag,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
cute::enable_if_t<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
"Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
TagToStrideA_t<GmemLayoutATag>,
ElementB,
TagToStrideB_t<GmemLayoutBTag>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity
>;
};
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps
template <
class ElementA,
class GmemLayoutATag,
int AlignmentA,
class ElementB,
class GmemLayoutBTag,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType
>
struct CollectiveBuilder<
arch::Sm90,
arch::OpClassTensorOp,
ElementA,
GmemLayoutATag,
AlignmentA,
ElementB,
GmemLayoutBTag,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
cute::enable_if_t<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
"Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
ElementA, ElementB, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
TagToStrideA_t<GmemLayoutATag>,
ElementB,
TagToStrideB_t<GmemLayoutBTag>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity
>;
};
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,61 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
namespace cutlass::gemm {
// Standard non-persistent kernel with a single producer warp, and one prefetch warp.
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
// while the producer warp is waiting on griddepcontrol.
// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and
// according to prefetch ratio.
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { };
// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp.
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not
// wait on griddepcontrol and loads immediately.
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { };
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch
>
struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
};
} // namespace cutlass::gemm

View File

@ -0,0 +1,867 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/arch/grid_dependency_control.h"
#include "dispatch_policy_extra.hpp"
#include "../pipeline/prefetch_pipeline_sm90.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <
int Stages,
class ClusterShape,
class KernelSchedule,
class TileShape_,
class ElementA_,
class StrideA_,
class ElementB_,
class StrideB_,
class TiledMma_,
class GmemTiledCopyA_,
class SmemLayoutAtomA_,
class SmemCopyAtomA_,
class TransformA_,
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>,
TileShape_,
ElementA_,
StrideA_,
ElementB_,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
{
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1");
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
static constexpr int PrefetchStages = 4;
static constexpr int PrefetchInitialStages = 1;
// This determines how much shmem we set aside for prefetch.
// We don't reuse anything loaded by the prefetcher, so we can keep
// loading into the same place -- writes will conflict, but that costs
// far less than what keeping the prefetch footprint this small buys us.
static constexpr int PrefetchStagesActual = 1;
using PrefetcherPipeline = cutlass::PrefetchPipeline<PrefetchStages>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages);
static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages);
using PrefetchSmemLayoutA = decltype(make_layout(make_shape(
cute::Int<size<0>(SmemLayoutA{})>{},
cute::Int<size<1>(SmemLayoutA{})>{},
cute::Int<PrefetchStagesActual>{})));
static constexpr auto prefetch_smem_size = cute::cosize_v<PrefetchSmemLayoutA>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
// Defined outside the class where it's used, to work around MSVC issues
using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>;
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128, _0> {
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::array_aligned<typename TiledMma::ValTypeA, prefetch_smem_size> smem_prefetch;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
PrefetcherPipelineStorage prefetcher_pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
uint32_t mma_promotion_interval = 4;
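// overlap_ratio < 0 disables early launch of dependent grids;
// prefetch_ratio < 0 selects best-effort prefetching (stop once the DMA warps are past griddepcontrol).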
float overlap_ratio = 0.5;
float prefetch_ratio = -1.0;
};
// Device side kernel params
struct Params {
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy_A_sm90(
GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{}));
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy_B_sm90(
GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{}));
TMA_A tma_load_a;
TMA_B tma_load_b;
uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
float overlap_ratio = 0.5;
float prefetch_ratio = -1.0;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
(void) workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
GmemTiledCopyA{},
tensor_a,
SmemLayoutA{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{});
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
GmemTiledCopyB{},
tensor_b,
SmemLayoutB{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{});
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
return {
tma_load_a,
tma_load_b,
transaction_bytes,
transaction_bytes_mk,
transaction_bytes_nk,
args.overlap_ratio,
args.prefetch_ratio
};
}
template<class ProblemShape>
static bool
can_implement(
ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
return false;
}
if (args.overlap_ratio > 1.0) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n");
return false;
}
if (args.prefetch_ratio > 1.0) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n");
return false;
}
return true;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytesMK =
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
static constexpr uint32_t TmaTransactionBytesNK =
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
/// Returned tuple must contain at least two elements, with the first two elements being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl);
}
template <
class TensorA, class TensorB,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load(
Params const& mainloop_params,
MainloopPipeline pipeline,
PrefetcherPipeline prefetcher_pipeline,
PipelineState smem_pipe_write,
TensorA const& gA_mkl,
TensorB const& gB_nkl,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
float overlap_ratio = mainloop_params.overlap_ratio;
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for A
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Applies the mapping from cta_tma_a
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
// Applies the mapping from cta_tma_b
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
// We have to wait on dependent grids because of B.
cutlass::arch::wait_on_dependent_grids();
// Signal prefetcher to stop
prefetcher_pipeline.producer_arrive();
bool launch_dep_grids = false;
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
++k_tile_iter;
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
launch_dep_grids = true;
cutlass::arch::launch_dependent_grids();
}
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
cutlass::arch::launch_dependent_grids();
}
}
}
template <
class TensorA,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load_MK(
Params const& mainloop_params,
MainloopPipeline pipeline,
PrefetcherPipeline prefetcher_pipeline,
PipelineState smem_pipe_write,
TensorA const& gA_mkl,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
float overlap_ratio = mainloop_params.overlap_ratio;
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
//
// Prepare the TMA loads for A
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
// Applies the mapping from cta_tma_a
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
// Don't wait on dependent grids when loading `A`, because
// we assume `A` (weights) are static.
bool launch_dep_grids = false;
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
++k_tile_iter;
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
launch_dep_grids = true;
cutlass::arch::launch_dependent_grids();
}
// Advance smem_pipe_write
++smem_pipe_write;
}
if (!disable_gdc && !launch_dep_grids) {
cutlass::arch::launch_dependent_grids();
}
}
}
template <
class TensorB,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load_NK(
Params const& mainloop_params,
MainloopPipeline pipeline,
PrefetcherPipeline prefetcher_pipeline,
PipelineState smem_pipe_write,
TensorB const& gB_nkl,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for B
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
// Applies the mapping from cta_tma_b
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_b = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
// Ensure that this kernel does not touch global memory that the
// previous kernel has not yet flushed prior to this instruction
cutlass::arch::wait_on_dependent_grids();
// Signal prefetcher to stop
prefetcher_pipeline.producer_arrive();
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
template <
class TensorA,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
prefetch_MK(
Params const& mainloop_params,
PrefetcherPipeline prefetcher_pipeline,
PipelineState smem_pipe_write,
TensorA const& gA_mkl,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0;
float prefetch_ratio = do_best_effort_prefetch ? 1.0f : mainloop_params.prefetch_ratio;
int prefetch_iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5 * prefetch_ratio);
prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages);
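// The 0.5 factor matches the two k_tile_iter increments per loop iteration below:
// only every other K tile is prefetched. The iteration count is rounded up to a
// multiple of PrefetchStages and capped at k_tile_count.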
Tensor sA = make_tensor(
make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
//
// Prepare the TMA loads for A
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
// Applies the mapping from cta_tma_a
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
uint32_t prefetcher_stage = 0;
uint32_t prefetcher_phase = 0;
CUTLASS_PRAGMA_NO_UNROLL
for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) {
if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) {
break;
}
prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages);
using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType;
BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage);
int write_stage = 0;
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
++k_tile_iter;
++k_tile_iter;
prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase);
}
prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <
class FrgTensorC
>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline,
PipelineState smem_pipe_read,
FrgTensorC& accum,
int k_tile_count,
int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
"ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps(dma load) with some mmas in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum);
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
++smem_pipe_read;
}
warpgroup_fence_operand(accum);
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum);
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accum);
// UNLOCK smem_pipe_release, done _computing_ on it
pipeline.consumer_release(smem_pipe_release);
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
warpgroup_fence_operand(accum);
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,117 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.f, beta = 0.f;
float overlap_ratio = 0.5f, prefetch_ratio = 0.5f;
int iterations = 1000;
int n = 64, m = 1280, k = 8192, l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f);
cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "63_hopper_gemm_with_weight_prefetch\n\n"
<< " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. \n"
<< " For more details please refer to the source file.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --p=<f32> Prefetch ratio\n"
<< " --o=<f32> Overlap ratio\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "63_hopper_gemm_with_weight_prefetch" <<
" --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
/// Compute effective bandwidth in GB/sec
double effective_bandwidth(
double runtime_s,
size_t bytes_a,
size_t bytes_b,
size_t bytes_c,
size_t bytes_d
) const
{
static double const kBytesPerGiB = double(1ull << 30);
double bytes_in =
(double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A
(double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B
(beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C
double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D
double gb_total = (bytes_in + bytes_out) / kBytesPerGiB;
return gb_total / runtime_s;
}
};

View File

@ -0,0 +1,561 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cute/tensor.hpp"
#include "../collective/dispatch_policy_extra.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
// GEMM + Prefetch for the A tensor + (optional) split DMA warps
template <
class ProblemShape_,
class CollectiveMainloop_,
class CollectiveEpilogue_,
class TileScheduler_
>
class GemmUniversal<
ProblemShape_,
CollectiveMainloop_,
CollectiveEpilogue_,
TileScheduler_,
cute::enable_if_t<
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA> ||
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>
>
>
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
static constexpr bool SplitWarps = cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>;
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
static_assert(ArchTag::kMinComputeCapability >= 90);
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
"TMA warp-specialized kernel does not support specializing the tile scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<
TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
// Kernel level shared memory storage
struct SharedStorage {
// Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union
union TensorStorage {
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16, _1> {
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) PrefetcherPipelineStorage prefetcher;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 1;
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static
Params
to_underlying_arguments(Arguments const& args, void* workspace) {
(void) workspace;
auto problem_shape = args.problem_shape;
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
// swap M/N
get<0>(problem_shape) = get<1>(args.problem_shape);
get<1>(problem_shape) = get<0>(args.problem_shape);
}
return {
args.mode,
problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
};
}
static bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static
size_t
get_workspace_size(Arguments const& args) {
return 0;
}
static
cutlass::Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
CudaHostAdapter* cuda_adapter = nullptr) {
return Status::kSuccess;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3
get_grid_shape(Params const& params) {
auto cluster_shape = ClusterShape{};
auto tile_shape = TileShape{};
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
return TileScheduler::get_tiled_cta_shape_mnl(
problem_shape_MNKL, tile_shape, cluster_shape);
}
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void
operator()(Params const& params, char* smem_buf) {
using namespace cute;
using X = Underscore;
#if defined(__CUDA_ARCH_FEAT_SM90_ALL)
# define ENABLE_SM90_KERNEL_LEVEL 1
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(ENABLE_SM90_KERNEL_LEVEL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
Consumer = 1,
};
// Split mode: Warp0 loads NK and the epilogue, and Warp2 loads MK.
// Non-split mode: Warp0 loads MK, NK, and the epilogue; Warp2 is unused.
// Both modes use Warp1 to prefetch.
enum class ProducerWarpRole {
Warp0 = 0,
PrefetchMK = 1,
Warp2 = 2,
UnusedWarp = 3
};
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
if (warp_group_role == WarpGroupRole::Producer && (
producer_warp_role == ProducerWarpRole::Warp0 ||
producer_warp_role == ProducerWarpRole::Warp2)) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
}
if (warp_group_role == WarpGroupRole::Consumer) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
bool should_prefetch = params.mainloop.prefetch_ratio > 0;
using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline;
typename PrefetcherPipeline::Params prefetcher_pipeline_params;
prefetcher_pipeline_params.num_prefetchers = 1;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
prefetcher_pipeline_params.should_prefetch = should_prefetch;
prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk;
}
PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params);
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
}
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = [&] () {
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer thread blocks in the cluster
if constexpr (size(ClusterShape{}) > 1) {
// Non-prefetcher warps arrive and wait,
// Prefetcher warp can go ahead without waiting.
cute::cluster_arrive_relaxed();
if (warp_group_role != WarpGroupRole::Producer ||
producer_warp_role != ProducerWarpRole::PrefetchMK) {
cute::cluster_wait();
}
return [] () {};
}
else {
// __syncthreads(), but only for non-prefetcher warps
if (should_prefetch) {
// Use a named barrier to let the prefetcher warp start loading into the L2
// without waiting to sync with all other warps.
// All other warps need to sync because the mainloop pipeline init
// should be visible to all of them.
// Prefetcher has its own barriers, and the only warps it would need to sync
// with would be the DMA warps.
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z,
/*reserved_named_barriers_*/ 14);
// Prefetcher warp doesn't arrive on this barrier.
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
/*reserved_named_barriers_*/ 15);
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
__syncwarp();
prefetcher_arrive_barrier.arrive();
}
else if (warp_group_role == WarpGroupRole::Producer) {
prefetcher_arrive_barrier.arrive_and_wait();
cluster_arrive_barrier.arrive_and_wait();
}
else {
prefetcher_arrive_barrier.arrive();
cluster_arrive_barrier.arrive_and_wait();
}
} else {
__syncthreads();
}
return [] () {};
}
} ();
// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread block locality
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
TiledMma tiled_mma;
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
// Compute m_coord, n_coord, and l_coord with their post-tiled shapes
auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Get pipeline iterators and increments from tensor shapes
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
auto k_tile_count = size<3>(gA_mkl);
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer) {
if (producer_warp_role == ProducerWarpRole::Warp0) {
if constexpr(SplitWarps) {
collective_mainloop.load_NK(
params.mainloop,
mainloop_pipeline,
prefetcher_pipeline,
mainloop_pipe_producer_state,
gB_nkl,
blk_coord,
k_tile_iter, k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
}
else {
collective_mainloop.load(
params.mainloop,
mainloop_pipeline,
prefetcher_pipeline,
mainloop_pipe_producer_state,
gA_mkl, gB_nkl,
blk_coord,
k_tile_iter, k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
}
// Update starting mainloop pipeline state for the pipeline drain
mainloop_pipe_producer_state.advance(k_tile_count);
// Make sure mainloop consumer has been waited upon before issuing epilogue load
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
if (collective_epilogue.is_producer_load_needed()) {
// Ensure warp is converged before issuing epilogue loads
__syncwarp();
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline,
epi_load_pipe_producer_state,
problem_shape_MNKL,
blk_shape,
blk_coord,
tiled_mma,
lane_idx,
shared_storage.tensors.epilogue
);
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
}
}
else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) {
collective_mainloop.load_MK(
params.mainloop,
mainloop_pipeline,
prefetcher_pipeline,
mainloop_pipe_producer_state,
gA_mkl,
blk_coord,
k_tile_iter, k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
// Update starting mainloop pipeline state for the pipeline drain
mainloop_pipe_producer_state.advance(k_tile_count);
// Make sure mainloop consumer has been waited upon before issuing epilogue load
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) {
collective_mainloop.prefetch_MK(
params.mainloop,
prefetcher_pipeline,
mainloop_pipe_producer_state,
gA_mkl,
blk_coord,
k_tile_iter, k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
}
}
else if (warp_group_role == WarpGroupRole::Consumer) {
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
collective_mainloop.mma(
mainloop_pipeline,
mainloop_pipe_consumer_state,
accumulators,
k_tile_count,
warp_group_thread_idx,
shared_storage.tensors.mainloop,
params.mainloop
);
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop.mma_tail(
mainloop_pipeline,
mainloop_pipe_consumer_state,
k_tile_count
);
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
collective_epilogue.store(
epi_load_pipeline,
epi_load_pipe_consumer_state,
epi_store_pipeline,
epi_store_pipe_producer_state,
problem_shape_MNKL,
blk_shape,
blk_coord,
accumulators,
tiled_mma,
warp_group_thread_idx,
shared_storage.tensors.epilogue
);
collective_epilogue.store_tail(
epi_load_pipeline,
epi_load_pipe_consumer_state_next,
epi_store_pipeline,
epi_store_pipe_producer_state_next
);
}
#endif
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel

View File

@ -0,0 +1,161 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/barrier.h"
#include "cute/container/array.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace detail {
// MSVC work-around
template <int Stages>
struct PrefetcherPipelineSharedStorage {
using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier;
using Barrier = cutlass::arch::ClusterBarrier;
TransactionBarrier tma_barrier[Stages];
Barrier producer_ready_barrier;
};
} // end namespace detail
using namespace cute;
// Prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction
// barrier providing control over the number of concurrent outstanding TMA loads.
// There is also an additional cluster barrier, used only when `prefetch_ratio` is unset.
// `prefetch_ratio` determines how many K tiles get loaded; when it is unset, the prefetcher
// instead checks whether the DMA warps are done waiting on griddepcontrol and, if so,
// stops issuing more TMA loads.
template <int Stages_>
class PrefetchPipeline {
public :
static constexpr uint32_t Stages = Stages_;
using SharedStorage = detail::PrefetcherPipelineSharedStorage<Stages>;
using TransactionBarrier = typename SharedStorage::TransactionBarrier;
using Barrier = typename SharedStorage::Barrier;
using PrefetcherBarrierType = typename TransactionBarrier::ValueType;
struct Params {
uint32_t transaction_bytes = 0;
uint32_t num_prefetchers = 1;
bool should_prefetch = false;
};
// Constructor
CUTLASS_DEVICE
PrefetchPipeline(SharedStorage& storage, Params params)
: params_(params)
, tma_barrier_ptr_(&storage.tma_barrier[0])
, producer_ready_barrier_ptr_(&storage.producer_ready_barrier) {
int lane_predicate = cute::elect_one_sync();
if (params.should_prefetch && lane_predicate) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Stages; ++i) {
tma_barrier_ptr_[i].init(params.num_prefetchers);
}
producer_ready_barrier_ptr_[0].init(1);
}
}
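// Called by the DMA (producer) warps to signal the prefetcher, e.g. once they are done
// waiting on griddepcontrol, so that it can stop issuing further prefetch loads.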
CUTLASS_DEVICE
void producer_arrive() {
if (params_.should_prefetch) {
producer_ready_barrier_ptr_[0].arrive();
}
}
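// Polled by the prefetcher warp: returns true once the DMA warps have arrived (or whenever
// prefetching is disabled), which is the signal to exit the prefetcher loop.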
CUTLASS_DEVICE
bool have_producers_arrived() {
if (params_.should_prefetch) {
uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0);
auto barrier_status = static_cast<BarrierStatus>(barrier_status_);
if (barrier_status == BarrierStatus::WaitDone) {
return true; // exit prefetcher loop
}
return false;
}
return true;
}
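// Waits for the previous transaction on this stage to complete (when requested) and arms the
// stage's transaction barrier with the expected number of transaction bytes.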
CUTLASS_DEVICE
void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) {
if (params_.should_prefetch) {
if (should_wait) {
tma_barrier_ptr_[stage].wait(phase ^ 1);
}
tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes);
}
}
CUTLASS_DEVICE
void advance_prefetcher_state(uint32_t& stage, uint32_t& phase) {
if (params_.should_prefetch) {
stage++;
if (stage == Stages) {
stage = 0;
phase ^= 1;
}
}
}
CUTLASS_DEVICE
void prefetcher_tail(uint32_t stage, uint32_t phase) {
if (params_.should_prefetch) {
// Wait on any already-issued loads
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < stage; ++i) {
tma_barrier_ptr_[i].wait(phase);
}
}
}
CUTLASS_DEVICE
PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) {
return reinterpret_cast<PrefetcherBarrierType*>(&tma_barrier_ptr_[stage]);
}
private :
TransactionBarrier* tma_barrier_ptr_ = nullptr;
Barrier* producer_ready_barrier_ptr_ = nullptr;
Params params_;
};
} // end namespace cutlass
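To make the intended call sequence concrete, below is a minimal sketch of a prefetcher-warp loop driving this pipeline. The TMA issue itself is elided, and both the loop structure and the `k >= Stages` wait condition are assumptions inferred from the class above rather than a copy of the example's mainloop code.
// A sketch only, under the assumptions stated above.
template <class PrefetcherPipeline>
CUTLASS_DEVICE
void prefetch_k_tiles(PrefetcherPipeline& pipeline, int k_tile_count) {
  uint32_t stage = 0;
  uint32_t phase = 0;
  for (int k = 0; k < k_tile_count; ++k) {
    // Stop early once the DMA warps have signaled via producer_arrive().
    if (pipeline.have_producers_arrived()) { break; }
    // Wait for the previous transaction on this stage (only after the first pass over the
    // stages), then arm the barrier with the expected byte count.
    pipeline.prefetcher_acquire(stage, phase, /*should_wait=*/k >= int(PrefetcherPipeline::Stages));
    // ... issue the TMA load for K tile `k`, bound to pipeline.prefetcher_get_barrier(stage) ...
    pipeline.advance_prefetcher_state(stage, phase);
  }
  // Drain any loads issued in the current pass over the stages.
  pipeline.prefetcher_tail(stage, phase);
}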

View File

@ -140,6 +140,9 @@ foreach(EXAMPLE
57_hopper_grouped_gemm
58_ada_fp8_gemm
59_ampere_gather_scatter_conv
61_hopper_gemm_with_topk_and_softmax
62_hopper_sparse_gemm
63_hopper_gemm_with_weight_prefetch
)
add_subdirectory(${EXAMPLE})

View File

@ -186,8 +186,8 @@ int main(int argc, char** argv)
return -1;
}
// Equivalent check to the above
if (not weakly_compatible(block_shape, tensor_shape)) {
std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl;
if (not evenly_divides(tensor_shape, block_shape)) {
std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl;
return -1;
}
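For reference, a small sketch of the new check with made-up shapes (the example's actual tensor_shape and block_shape are defined earlier in that file):
auto tensor_shape = cute::make_shape(1024, 512);
auto block_shape  = cute::make_shape( 128,  64);
// evenly_divides is true here since 1024 % 128 == 0 and 512 % 64 == 0;
// changing block_shape to (128, 96) would take the error path above.
if (not cute::evenly_divides(tensor_shape, block_shape)) {
  std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl;
}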