CUTLASS 3.6.0 (#1850)
* v3.6
* update changelog
* update readme
* fix typo
* fixing typos
* hopper gemm with weight prefetch
---------
Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
@ -42,7 +42,8 @@
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/device/gemm_complex.h"

#include "cutlass/numeric_types.h"
#include "cutlass/numeric_size.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"

@ -56,6 +57,7 @@
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes

#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"
@ -657,7 +659,9 @@ struct Testbed {
}

int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n();
int64_t bytes = cutlass::bits_to_bytes(
(cutlass::sizeof_bits<ElementD>::value * 2 + cutlass::sizeof_bits<ElementSoftmax>::value) *
options.problem_size.m() * options.problem_size.n());

double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30);

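A note on the change above: once ElementD or ElementSoftmax can be a sub-byte type, byte counts must be derived from sizeof_bits rather than sizeof. A minimal sketch of the rounding that cutlass::bits_to_bytes is assumed to perform (the helper name bits_to_bytes_sketch is hypothetical, not CUTLASS API):

#include <cassert>
#include <cstdint>

// Hypothetical stand-in that rounds a bit count up to whole bytes.
constexpr int64_t bits_to_bytes_sketch(int64_t bits) { return (bits + 7) / 8; }

int main() {
  assert(bits_to_bytes_sketch(1000 * 4) == 500);   // 1000 int4 values: sizeof-based math cannot express half-byte elements
  assert(bits_to_bytes_sketch(1000 * 8) == 1000);  // byte-sized elements: agrees with the old sizeof-based formula
  return 0;
}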
@ -303,14 +303,14 @@ bool initialize_block(
int bits_input = cutlass::sizeof_bits<Element>::value;

if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
scope_max = Element(2);
scope_min = Element(0);
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
scope_max = Element(2);
scope_min = Element(-2);
} else {
scope_max = 8;
scope_min = -8;
scope_max = Element(8);
scope_min = Element(-8);
}

cutlass::reference::device::BlockFillRandomUniform(

@ -111,7 +111,7 @@ public:
EpilogueTensorStorage epilogue;
} tensors;

struct PipelineStorage : cute::aligned_struct<16> {
struct PipelineStorage : cute::aligned_struct<16, _2> {
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;

@ -50,7 +50,7 @@ struct PermuteTraits {};
using X = Underscore;

// Reshape a rank-2 shape into a multidimensional shape.
// Input:
// shape = (A, B, ...)
// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...)
// Output:
@ -76,12 +76,12 @@ reshape(Shape const& shape, TargetShape const& target_shape)
// - sub-modes corresponding to the implied multidimensional shape of the source tensor
// - strides accounting for the permutation operation being performed
template<class Permute, bool Transpose, class Shape, class Stride>
constexpr auto
make_permute_layout(Layout<Shape,Stride> const& layout) {
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
if constexpr (Transpose) {
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
return select<1,0,2>(make_permute_layout<Permute, false>(select<1,0,2>(layout)));
}
else {
@ -129,23 +129,24 @@ inverse(Permutation const & perm) {
template<class T>
using inverse_t = decltype(inverse(T{}));

// Given a rank-2 layout of tensor that is assumed to have been permuted,
// compute the original rank-2 layout of the tensor prior to the permutation.
// This is needed to form the correct input to the standalone permutation kernel.
template<class Permute, bool Transpose, class Shape, class Stride>
constexpr auto
make_original_layout(Layout<Shape,Stride> const& layout) {
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
if constexpr (Transpose) {
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
return select<1,0,2>(make_original_layout<Permute, false>(select<1,0,2>(layout)));
}
else {
using ShapeProfile = typename PermuteTraits<Permute>::ShapeProfile;
auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{}));
using IndexOrder = typename PermuteTraits<Permute>::IndexOrder;
auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get<i>(re_shape); });
using OrigOrder = conditional_t<cutlass::gemm::detail::is_major<0,Stride>(), seq<0,1,2>, seq<1,0,2>>;
auto orig_shape = select(flatten(reshape(layout.shape(), ShapeProfile{})), IndexOrder{});
// print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n");
// print("Original shape: "); print(orig_shape); print("\n");
return make_ordered_layout(product_each(orig_shape), OrigOrder{});
@ -202,7 +203,7 @@ struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D>>
};

template<int D>
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D>>
{
static constexpr bool kBatched = true;
using ShapeProfile = Shape<Shape<X,Int<D>>, Shape<X>, Shape<X>>;
@ -222,7 +223,7 @@ struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D>>
};

template<int D>
struct PermuteTraits<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D>>
{
static constexpr bool kBatched = true;
using ShapeProfile = Shape<Shape<X>, Shape<X,Int<D>>, Shape<X>>;

examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu (new file, 701 lines)
@ -0,0 +1,701 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
\brief Hopper mixed-input GEMM example (INT4 x FP8) using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture

This example shows how to perform an INT4 x FP8 GEMM and scale the INT4 weights during dequantization. It uses a look-up table to avoid the multiplications
between INT4 and FP8. To trigger this method, use cutlass::Array<ElementScale, 8> as the scale type in the collective's arguments.

However, this algorithm requires changes to the encoding of the INT4 weights and scale factors. These changes must happen before launching the GEMM. See the helper functions
`unify_quant_encoding`, `initialize_packed_scale`, and the header `packed_scale.hpp` for details.

In a nutshell, the positive values of the INT4 weights need to be encoded in the same way as the negative values except for the sign bit. For each scale factor,
8 negative results (-8 x scale, -7 x scale, ... -1 x scale) are packed together, forming a cutlass::Array<ElementScale, 8> value.

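As an illustration of the scheme described above (a reading of the code in this example, not additional API): for a scale factor s, the packed cutlass::Array<ElementScale, 8> holds (-8*s, -7*s, ..., -1*s). An INT4 value of +3 is re-encoded by unify_quant_encoding as 0b0101; its low three bits (101) select the entry -3*s from that array, and the sign bit is then handled separately by an extra PRMT in the mainloop.
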
The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
A and B in the main loop. Because the collective performs this implicit swap, it does not support TMA epilogues, so this must be taken into account when constructing the epilogue,
as illustrated in this example.

Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.

It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).

Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.

If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, they must have shape [N, scale_k].

The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size
equal to the gemm problem K.

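For example, with problem_k = 8192 and group_size = 128, scale_k = ceil_div(8192, 128) = 64; since B is the scaled operand in this example, with the default N = 4096 the scale tensor has shape [4096, 64]. Setting the group size to 8192 collapses this to a single scale per column, i.e. shape [4096, 1].
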
Limitations:
1) Only supports INT4 x { FP8, INT8, UINT8 }. The scales must have the same type as the MMA type. Scale-with-zero-point mode is not supported.
2) The INT4 weights and scale factors have additional encoding requirements.
3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
4) The scales must have the same layout and group size.
5) The group size must be greater than or equal to the tile shape K.
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
We plan to address this limitation in the future; in this example we work around it by explicitly swapping and transposing the operands.

Optimization suggestions:
1) Use a small tile size, since register pressure for this GEMM (and RS GEMMs in general) is high.

Examples:

Runs the mixed input batched gemm (with batch size 2), scaling B with the default group size of 128:
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=2048 --n=2048 --k=2048 --l=2

Runs the mixed input gemm and applies a per-column vector of scales to B (the group size equals the gemm K dimension):
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=4096 --n=5120 --k=8192 --g=8192
*/

#include <iostream>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"

#include "helper.h"
#include "unfused_weight_dequantize.hpp"
#include "packed_scale.hpp"

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
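// For the 8-bit MmaType used here this evaluates to 128, i.e. a 128x128x128 tile together with TileShape below.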

// A matrix configuration
using ElementA = MmaType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = QuantType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;

using ElementScale = MmaType;
using ElementZero = ElementScale; // only for verify
using LayoutScale = cutlass::layout::RowMajor;

// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementAccumulator,
// Transpose layout of D here since we use explicit swap + transpose
// the void type for C tells the builder to allocate 0 smem for the C matrix.
// We can enable this if beta == 0 by changing ElementC to void below.
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
>::CollectiveOp;

// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8> >, LayoutB_Transpose, AlignmentB,
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;

using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;

using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;

using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = typename GemmKernelScaleOnly::StrideC;
using StrideD = typename GemmKernelScaleOnly::StrideD;

using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;

//
// Data members
//

/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideC_ref stride_C_ref;
StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;

using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;

cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementB> block_B_modified;
cutlass::DeviceAllocation<ElementA> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8>> block_scale_packed;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

bool help = false;

float alpha = 1.0f;
float beta = 0.0f;
int iterations = 10;
int m = 5120, n = 4096, k = 4096;
int g = 128;
int l = 1;

// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);

if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}

cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("g", g);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}

/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {

out << "55_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
||||
<< " --g=<int> The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";

return out;
}

/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
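// e.g. m = n = k = 4096, l = 1 gives 2 * 4096^3 = ~1.37e11 flop, so a 1 ms run is roughly 137 TFLOP/s.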
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};

/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t error = cudaSuccess;
bool passed = false;

};

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to initialize a block of device data
template <class Element>
bool initialize_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {

double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<Element>::value;

if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));

return true;
}

template <typename Element>
bool initialize_quant_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {

float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());

cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));

return true;
}

// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
// Here the encodings of positive values and negative values are unified (except for the sign bit).
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
bool unify_quant_encoding(
cutlass::DeviceAllocation<cutlass::int4b_t> const& block_in,
cutlass::DeviceAllocation<cutlass::int4b_t>& block_out) {

using StorageType = cutlass::int4b_t::Storage;

if (block_in.size() != block_out.size()) {
std::cerr << "block_in and block_out must have same size.\n";
return false;
}
constexpr int pack = sizeof_bits_v<StorageType> / 4;
std::vector<StorageType> data(block_in.size() / pack);
cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack);

for (auto&& d : data) {
StorageType out = 0;
StorageType mask = 0x0f;
for (int i = 0; i < pack; ++i) {
cutlass::int4b_t curr;
curr.storage = (d >> (i * 4)) & 0x0f;
switch (curr) {
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
default: break;
}
out |= (curr.storage << (4 * i)) & mask;
mask <<= 4;
}
d = out;
}

cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2);
return true;
}

template <class Element>
bool initialize_scale(
cutlass::DeviceAllocation<Element>& block,
Options const& options) {

float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
float const max_dequant_val = 4.f;
float const min_dequant_val = 0.5f;

float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);
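// With QuantType = int4b_t (max = 7) this draws scales from [0.5/7, 4/7], i.e. roughly [0.071, 0.571].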

cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}

bool initialize_packed_scale(
cutlass::DeviceAllocation<ElementScale> const& block_in,
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8> > & block_out) {

std::vector<ElementScale> data_in(block_in.size());
std::vector<cutlass::Array<ElementScale, 8> > data_out(block_in.size());
try {
block_in.copy_to_host(data_in.data());
} catch (cutlass::cuda_exception const& e)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
for (size_t i = 0; i < block_in.size(); ++i)
{
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
// std::cout << data_in[i] << ":" << std::hex << static_cast<uint16_t>(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast<uint16_t>((-data_in[i]).storage) << std::endl;
}
try {
block_out.copy_from_host(data_out.data());
} catch (cutlass::cuda_exception const& e)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
return false;
}
return true;
}

template <class Element>
bool initialize_zero(
cutlass::DeviceAllocation<Element>& block,
Options const& options) {
std::vector<Element> stage(block.size(), Element(0.0f));
block.copy_from_host(stage.data());
return true;
}

/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options const& options) {

auto shape_b = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
// Reverse stride here due to swap and transpose
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));

auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);

block_A.reset(a_coord.product());
block_B.reset(b_coord.product());
block_B_modified.reset(b_coord.product());
block_B_dq.reset(b_coord.product());
block_C.reset(c_coord.product());
block_D.reset(c_coord.product());
block_ref_D.reset(c_coord.product());

block_scale.reset(scale_k * options.l * options.n);
block_scale_packed.reset(scale_k * options.l * options.n);
block_zero.reset(scale_k * options.l * options.n);

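// Note: the kernel consumes the re-encoded weights (block_B_modified) together with the packed scales,
// while the reference path below dequantizes the original encoding (block_B) with the unpacked scales.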
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
unify_quant_encoding(block_B, block_B_modified);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_packed_scale(block_scale, block_scale_packed);
initialize_zero(block_zero, options);

auto layout_B = make_layout(shape_b, stride_B);

auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);

dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
}

/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Args>
Args args_from_options(Options const& options)
{
// Swap the A and B tensors, as well as problem shapes here.

return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}

bool verify(Options const& options) {
//
// Compute reference output
//

// In this example, we use the GPU default kernels as a reference (unfused scale).
// This avoids numerical differences due to different accumulation order.

// Again, due to numerical differences, we must use fast acc here when the mma type is
// FP8 as the fused implementation only supports fast acc at the moment.
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;

using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
ScheduleRef
>::CollectiveOp;

using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;

using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;

using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;

typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
};

// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());

// compare_reference
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
}

/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);

// Instantiate CUTLASS kernel depending on templates
Gemm gemm;

// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<typename Gemm::Arguments>(options);

// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));

// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));

// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());

// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);

std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

if (!result.passed) {
exit(-1);
}

// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();

// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}

return 0;
}

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}

cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//

Options options;

options.parse(argc, args);

if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}

//
// Evaluate CUTLASS kernels
//

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
if (options.g == options.k) {
std::cout << "Running in per-column scale mode." << std::endl;
} else {
std::cout << "Running in group scale mode." << std::endl;
}
run<GemmScaleOnly>(options);
#endif

return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -53,14 +53,18 @@
equal to the gemm problem K.

Limitations:
1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}.
2) The narrow type must always be in K-major format.
3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
4) The scales and the zeros must have the same layout and groupsize.
1) The narrow type must always be in K-major format.
2) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
3) The scales and the zeros must have the same layout and groupsize.
4) The groupsize must be greater than or equal to the tile shape K.
5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format.
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands.

Optimizing suggestions:
1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space).
2) Try to avoid using the scale or zero modes, because the extra conversion computations become the bottleneck.

Examples:

@ -94,11 +98,8 @@
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"

#include "helper.h"
#include "unfused_weight_dequantize.hpp"
@ -117,8 +118,8 @@ enum GemmMode {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
using MmaType = cutlass::half_t;
using QuantType = cutlass::float_e4m3_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;

// A matrix configuration
@ -154,8 +155,8 @@ using ElementAccumulator = float; // E
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_256,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
@ -268,14 +269,14 @@ using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;

cutlass::HostTensor<MmaType, LayoutA> tensor_A;
cutlass::HostTensor<QuantType, LayoutB> tensor_B;
cutlass::HostTensor<MmaType, LayoutB> tensor_B_dq;
cutlass::HostTensor<ElementScale, LayoutScale> tensor_scale;
cutlass::HostTensor<ElementZero, LayoutScale> tensor_zero;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_D;
cutlass::HostTensor<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput, LayoutD> tensor_ref_D;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementA> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::ElementOutput> block_ref_D;

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

@ -290,7 +291,7 @@ struct Options {

float alpha = 1.0f;
float beta = 0.0f;
int iterations = 1000;
int iterations = 10;
int mode = 2;
int m = 5120, n = 4096, k = 4096;
int g = 128;
@ -368,9 +369,9 @@ struct Result
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to initialize a block of device data
template <class Element, class Layout>
template <class Element>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {

double scope_max, scope_min;
@ -393,34 +394,35 @@ bool initialize_tensor(
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));

return true;
}

template <typename Element, typename Layout>
template <typename Element>
bool initialize_quant_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {

float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());

cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));

return true;
}

template <class Element, class Layout>
template <class Element>
bool initialize_scale(
cutlass::TensorView<Element, Layout> view,
const Options &options) {
cutlass::DeviceAllocation<Element>& block,
Options const& options) {

if (options.mode == GemmMode::ConvertOnly) {
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(1.0f));
std::vector<Element> stage(block.size(), Element(1.0f));
block.copy_from_host(stage.data());
}
else {
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
@ -430,32 +432,33 @@ bool initialize_scale(
float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);

cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
}
return true;
}

template <class Element, class Layout>
template <class Element>
bool initialize_zero(
cutlass::TensorView<Element, Layout> view,
const Options &options) {
cutlass::DeviceAllocation<Element>& block,
Options const& options) {

if (options.mode == GemmMode::ScaleWithZeroPoint) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2.0f, -2.0f);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
} else {
// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
cutlass::reference::host::TensorFill(view, Element(0.0f));
std::vector<Element> stage(block.size(), Element(0.0f));
block.copy_from_host(stage.data());
}
return true;
}

/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
void initialize(Options const& options) {

auto shape_b = cute::make_shape(options.n, options.k, options.l);
const int scale_k = (options.k + options.g - 1) / options.g;
int const scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
// Reverse stride here due to swap and transpose
@ -469,27 +472,21 @@ void initialize(const Options &options) {
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);

tensor_A.resize(a_coord);
tensor_B.resize(b_coord);
tensor_B_dq.resize(b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
block_A.reset(a_coord.product());
block_B.reset(b_coord.product());
block_B_dq.reset(b_coord.product());
block_C.reset(c_coord.product());
block_D.reset(c_coord.product());
block_ref_D.reset(c_coord.product());

tensor_scale.resize({scale_k * options.l, options.n});
tensor_zero.resize({scale_k * options.l, options.n});
block_scale.reset(scale_k * options.l * options.n);
block_zero.reset(scale_k * options.l * options.n);

initialize_tensor(tensor_A.host_view(), seed + 2022);
initialize_quant_tensor(tensor_B.host_view(), seed + 2021);
initialize_tensor(tensor_C.host_view(), seed + 2020);
initialize_scale(tensor_scale.host_view(), options);
initialize_zero(tensor_zero.host_view(), options);

tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_scale.sync_device();
tensor_zero.sync_device();
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);

auto layout_B = make_layout(shape_b, stride_B);

@ -498,37 +495,36 @@ void initialize(const Options &options) {
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);

dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g);
tensor_B_dq.sync_host();
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
}

/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Args>
Args args_from_options(const Options &options)
Args args_from_options(Options const& options)
{
// Swap the A and B tensors, as well as problem shapes here.
if (options.mode == GemmMode::ConvertOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
{block_B.get(), stride_B, block_A.get(), stride_A},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
else if (options.mode == GemmMode::ScaleWithZeroPoint) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
} else {
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
@ -542,7 +538,7 @@ bool verify(const Options &options) {
//

// In this example, we use the GPU default kernels as a reference (unfused scale)
// This is to avoid numerical differences from different accumulation order.
// This avoids numerical differences due to different accumulation order.

// Again, due to numerical differences, we must use fast acc here when the mma type is
// FP8 as the fused implementation only supports fast acc at the moment.
@ -581,8 +577,8 @@ bool verify(const Options &options) {
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B},
{{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref}
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
};

// Run the gemm where the scaling is performed outside of the kernel.
@ -594,11 +590,9 @@ bool verify(const Options &options) {
CUTLASS_CHECK(gemm_ref.run());

// compare_reference
tensor_D.sync_host();
tensor_ref_D.sync_host();
const ElementD epsilon(1e-2f);
const ElementD non_zero_floor(1e-4f);
bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor);
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
}

@ -55,5 +55,16 @@ cutlass_example_add_executable(
TEST_SCALE_ZERO_GROUPED
TEST_SCALE_RESIDUE
TEST_SCALE_ZERO_RESIDUE
TEST_ALPHA_BETA
# TEST_ALPHA_BETA
)

cutlass_example_add_executable(
55_hopper_int4_fp8_gemm
55_hopper_int4_fp8_gemm.cu
TEST_COMMAND_OPTIONS
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_GROUP
TEST_SCALE_RESIDUE
# TEST_ALPHA_BETA
)

@ -11,6 +11,8 @@ This first version only supports mixed type GEMMs using TMA.
|
||||
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance when the scales and zero-points are `fp16`, `bf16`, or `fp32`. For best performance, the scales and zero-points should have the same type.
|
||||
|
||||
The scale-only mode for `fp8 x int4` is significantly slower than the direct-conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, pass `cutlass::Array<ElementScale, 8>` as the scale type to the collective builder (see the sketch below). However, this requires modifications to the encoding of the quantized weights and scale factors, and the scale-with-zero-point mode is not supported with it for now.
|
||||
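As a rough sketch of what that looks like (type names here are illustrative and not copied from the example):

#include "cutlass/numeric_types.h"   // cutlass::int4b_t, cutlass::float_e4m3_t
#include "cutlass/array.h"           // cutlass::Array

// Hypothetical aliases for illustration only; see 55_hopper_int4_fp8_gemm.cu for the
// authoritative kernel configuration.
using ElementQuant       = cutlass::int4b_t;                 // quantized weight type
using ElementScale       = cutlass::float_e4m3_t;            // scale factor type
using ElementScalePacked = cutlass::Array<ElementScale, 8>;  // selects the lookup-table path

// ElementScalePacked is then handed to the collective builder wherever the scale element
// type would normally appear (for example, inside the tuple describing the quantized
// operand), instead of the plain ElementScale.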
|
||||
We are currently optimizing the following cases:
|
||||
1. Memory bound cases for all types
|
||||
|
||||
|
||||
examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp (new file, 132 lines)
@ -0,0 +1,132 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdint>
|
||||
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
template<typename T>
|
||||
class packed_scale_t {
|
||||
public:
|
||||
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
|
||||
cute::is_same_v<T, cutlass::uint8_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e5m2_t>,
|
||||
"only 8 bit arithmetic types are supported.");
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit packed_scale_t(T val) {
|
||||
if constexpr (!cute::is_unsigned_v<T>) {
|
||||
// Only pack negative values. The positive values are generated in flight in the mainloop.
|
||||
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
|
||||
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
|
||||
}
|
||||
else {
|
||||
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
|
||||
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
|
||||
}
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
packed_scale_t() = default;
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit operator float() const {
|
||||
return float(get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(packed_scale_t const& rhs) const {
|
||||
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(packed_scale_t const& rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() + rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() - rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() * rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() / rhs.get());
|
||||
}
|
||||
|
||||
private:
|
||||
using Storage = uint32_t;
|
||||
using Stage = uint8_t;
|
||||
|
||||
Storage storage[2] {};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Storage pack4(T c1, T c2, T c3, T c4) {
|
||||
Storage result = 0;
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
|
||||
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
|
||||
return result;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get() const {
|
||||
auto stage = static_cast<Stage>(storage[0] >> 8);
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get(int idx) const {
|
||||
Stage stage;
|
||||
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
|
||||
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
}
|
||||
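For reference, a minimal host-side sketch of how this type might be exercised (a hypothetical snippet; it assumes the CUTLASS/CuTe headers that packed_scale.hpp relies on are available on the include path):

#include "cute/tensor.hpp"       // brings in the CuTe type traits used by the header
#include "cutlass/cutlass.h"     // CUTLASS_HOST_DEVICE
#include "packed_scale.hpp"      // relative include path assumed

void packed_scale_demo() {
  using Scale = cutlass::float_e4m3_t;

  // The constructor packs eight multiples of the scale into two 32-bit words
  // (negative multiples for signed types; positive ones are generated in flight
  // by the mainloop).
  cutlass::packed_scale_t<Scale> packed(Scale(0.5f));
  cutlass::packed_scale_t<Scale> same(Scale(0.5f));

  // Public surface: explicit conversion back to float, comparison of the packed
  // words, and arithmetic that combines two packed values and repacks the result.
  float representative = float(packed);
  bool equal = (packed == same);   // true: identical packed representations
  auto product = packed * same;    // combined and repacked
  (void)representative; (void)equal; (void)product;
}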
@ -32,7 +32,7 @@
|
||||
/*! \file
|
||||
\brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
|
||||
|
||||
This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA
|
||||
This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA
|
||||
warp-specialized cooperative kernel.
|
||||
The new feature showcased in this example is on-the-fly modification of TMA descriptors
|
||||
to move between batches (represented by l).
|
||||
@ -547,3 +547,4 @@ int main(int argc, char const **args) {
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -91,9 +91,9 @@
|
||||
|
||||
using namespace cute;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
|
||||
@ -26,6 +26,8 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (NOT MSVC)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
59_ampere_gather_scatter_conv
|
||||
ampere_gather_scatter_conv.cu
|
||||
@ -34,3 +36,5 @@ cutlass_example_add_executable(
|
||||
if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND)
|
||||
target_link_libraries(59_ampere_gather_scatter_conv PRIVATE OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
@ -0,0 +1,534 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper GEMM + Top-K + Softmax fusion
|
||||
|
||||
This example illustrates how to use the LinCombTopKSoftmaxCol EVT node to fuse
|
||||
Top-K and Softmax into the GEMM epilogue, with certain assumptions made.
|
||||
|
||||
Those assumptions are as follows:
|
||||
1. Fusion is over the N dimension.
|
||||
2. Top-K is either 2 or 4 elements, and the value is static (meaning two kernels have to be
|
||||
compiled to support both).
|
||||
3. The GEMM tile shape along N is greater than or equal to problem size
|
||||
along N.
|
||||
|
||||
|
||||
The example runs the fused GEMM kernel, along with a standard unfused host reference, and
|
||||
manually performs Top-K and softmax, and compares the error between tensors.
|
||||
|
||||
Note that some numerical error (smaller than 1e-5) is to be expected, as is the case
|
||||
in most efficient reduction kernels, because floating point addition is not necessarily
|
||||
associative.
|
||||
*/
|
||||
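In formula form: writing X for the GEMM output entering the epilogue, T_i for the index set of the K largest entries of row i, and m_i = max_j X_ij, the fused epilogue produces (notation introduced here for clarity; it is not taken from the source):

$$
D_{ij} =
\begin{cases}
\dfrac{\exp(X_{ij} - m_i)}{\sum_{l \in \mathcal{T}_i} \exp(X_{il} - m_i)} & \text{if } j \in \mathcal{T}_i, \\
0 & \text{otherwise,}
\end{cases}
$$

which matches the host reference computed in verify() below.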
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
static constexpr int TopK = 2;
|
||||
static constexpr bool EnableTopKSoftmax = TopK > 1;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentC = 1;
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = cutlass::half_t; // Element type for D matrix operand
|
||||
using LayoutD = cutlass::layout::RowMajor; // Layout type for output
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of output in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
|
||||
// Top-K + Softmax fusion operation
|
||||
using FusionOperation = std::conditional_t<EnableTopKSoftmax,
|
||||
typename cutlass::epilogue::fusion::LinCombTopKSoftmaxCol<TopK, ElementD, ElementCompute>,
|
||||
typename cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute>
|
||||
>;
|
||||
|
||||
// The fusion op only allows for epilogue tiles matching the mainloop tile.
|
||||
using EpilogueTileType = decltype(cute::take<0,2>(TileShape{}));
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
int iterations = 1000;
|
||||
int m = 16, n = 8, k = 64, l = 1;
|
||||
double eps = 1e-5;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("eps", eps);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "61_hopper_gemm_with_topk_and_softmax\n\n"
|
||||
<< " Hopper FP8 GEMM with Top-K and softmax fusion.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --eps=<float> Threshold of numerical verification. Default: 1e-5.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "61_hopper_gemm_with_topk_and_softmax" << " --m=16 --n=8 --k=1024 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
float alpha() const {
|
||||
return 1.f / static_cast<float>(k);
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, /* max = */ 1, /* min = */ -1, /* bits = */ 2);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2023);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_D.sync_device();
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
|
||||
{
|
||||
{options.alpha(), 0.f}, // alpha, beta
|
||||
nullptr, stride_D,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
unused_t,
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha();
|
||||
epilogue_params.beta = 0.f;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
if constexpr (EnableTopKSoftmax) {
|
||||
// top-K + softmax
|
||||
for (int i = 0; i < options.m; ++i) {
|
||||
|
||||
// Find Top-K
|
||||
cutlass::Array<ElementAccumulator, TopK> top_k;
|
||||
top_k.fill(-cutlass::platform::numeric_limits<ElementCompute>::infinity());
|
||||
for (int j = 0; j < options.n; ++j) {
|
||||
auto val = static_cast<ElementAccumulator>(tensor_ref_D.host_view().ref().at({i, j}));
|
||||
for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) {
|
||||
if (val > top_k[top_k_idx]) {
|
||||
// Shift down
|
||||
for (int l = TopK - 1; l > top_k_idx; --l) {
|
||||
top_k[l] = top_k[l - 1];
|
||||
}
|
||||
top_k[top_k_idx] = val;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This formulation of top-K + softmax only works when it is
|
||||
// guaranteed that none of the top-K elements are repeated!
|
||||
// If repeated values do occur, the device kernel can also make mistakes, because
|
||||
// A. Once the top-K values are reduced, and the operation is being applied,
|
||||
// there is no way to tell repeated elements apart, so none are masked.
|
||||
// B. The softmax sum of exps will be incorrect (because the repeated elements
|
||||
// are only counted once in it).
|
||||
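// For example, with TopK = 2 and a row {5, 3, 3}: both 3s pass the threshold below,
// so three entries stay non-zero, and the sum of exps only counts exp(3 - 5) once.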
|
||||
ElementAccumulator max = top_k[0];
|
||||
ElementAccumulator sum = ElementAccumulator(0.f);
|
||||
for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) {
|
||||
sum = sum + cutlass::fast_exp(top_k[top_k_idx] - max);
|
||||
}
|
||||
|
||||
for (int j=0; j < options.n; ++j) {
|
||||
auto val = tensor_ref_D.host_view().ref().at({i, j});
|
||||
if (val < top_k[TopK - 1]) {
|
||||
tensor_ref_D.host_view().ref().at({i, j}) = static_cast<ElementD>(0.f);
|
||||
} else {
|
||||
// Softmax
|
||||
auto softmax_val = cutlass::fast_exp(val - max) / sum;
|
||||
tensor_ref_D.host_view().ref().at({i, j}) = static_cast<ElementD>(softmax_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
|
||||
double err = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_D.host_view(),
|
||||
tensor_ref_D.host_view());
|
||||
bool passed = err < options.eps;
|
||||
|
||||
if (options.m <= 32 && options.n <= 32) {
|
||||
std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n";
|
||||
std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options) {
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt (new file, 32 lines)
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
61_hopper_gemm_with_topk_and_softmax
|
||||
61_hopper_gemm_with_topk_and_softmax.cu
|
||||
)
|
||||
examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu (new file, 596 lines)
@ -0,0 +1,596 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper Sparse GEMM example.
|
||||
|
||||
This example demonstrates how to construct and run a structured sparse GEMM kernel
|
||||
on NVIDIA Hopper architecture.
|
||||
|
||||
*/
|
||||
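For the half-precision A operand used here, the structured-sparsity pattern is 2:4 along K: at most two non-zeros in every group of four consecutive elements, with the positions of the surviving values recorded in the metadata tensor E. A hypothetical helper, not part of the example, illustrating the constraint the compressor enforces (it assumes k is a multiple of four):

#include <cstddef>

template <class Element>
bool is_2_to_4_sparse(Element const* row, std::size_t k) {
  // Walk the row in groups of four consecutive elements along K and count non-zeros.
  for (std::size_t g = 0; g + 4 <= k; g += 4) {
    int nonzeros = 0;
    for (int i = 0; i < 4; ++i) {
      nonzeros += (row[g + i] != Element(0)) ? 1 : 0;
    }
    if (nonzeros > 2) {
      return false;   // more than two non-zeros in a group of four violates 2:4 sparsity
    }
  }
  return true;
}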
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size for sparse kernel
|
||||
using TileShapeRef = Shape<_128,_128, _64>; // Threadblock-level tile size for reference (dense) kernel
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; // Kernel schedule policy
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue schedule policy
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
|
||||
// Sparse kernel setup
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||
ElementA, LayoutTagA, AlignmentA,
|
||||
ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference (dense) kernel setup
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShapeRef, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
ElementC, LayoutTagC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutTagA, AlignmentA,
|
||||
ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShapeRef, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
|
||||
// Layouts
|
||||
using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
// Layouts for reference (non-sparse) tensors
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
|
||||
using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
// Offline compressor kernel
|
||||
using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig>;
|
||||
|
||||
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape,
|
||||
ElementA,
|
||||
LayoutTagA,
|
||||
SparseConfig,
|
||||
cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ProblemShape problem_shape;
|
||||
|
||||
StrideA stride_A;
|
||||
StrideA stride_A_compressed;
|
||||
StrideE stride_E;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
|
||||
LayoutA layout_A;
|
||||
LayoutE layout_E;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A_compressed;
|
||||
cutlass::DeviceAllocation<typename Gemm::CollectiveMainloop::ElementE> block_E;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D_ref;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k, l;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(5120), n(4096), k(16384), l(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "62_hopper_sparse_gemm\n\n"
|
||||
<< " Hopper Sparse GEMM example.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent of the GEMM (batch size)\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "62_hopper_sparse_gemm" << " --m=4096 --n=5120 --k=8192 --l=1 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
bool sparsify_and_compress()
|
||||
{
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
CompressorUtility compressor_utility(problem_shape, stride_A);
|
||||
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
block_A_compressed.reset(M * KC * L);
|
||||
block_E.reset(ME * KE * L);
|
||||
|
||||
stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L));
|
||||
stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L));
|
||||
|
||||
// Random sparsification is performed on host
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
|
||||
compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast<int>(seed + 2024));
|
||||
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
typename Compressor::Arguments arguments {
|
||||
problem_shape,
|
||||
{ block_A.get(),
|
||||
stride_A,
|
||||
block_A_compressed.get(),
|
||||
block_E.get() },
|
||||
{hw_info} };
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
bool initialize(Options const& options) {
|
||||
|
||||
problem_shape = make_tuple(options.m, options.n, options.k, options.l);
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
// Allocate memory for tensors
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(N * K * L);
|
||||
block_C.reset(M * N * L);
|
||||
block_D.reset(M * N * L);
|
||||
block_D_ref.reset(M * N * L);
|
||||
|
||||
// Fill input tensors with data
|
||||
initialize_block(block_A, seed + 2021);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2023);
|
||||
|
||||
// Replace 0 in A with 1 to avoid metadata changes
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size());
|
||||
for (size_t i = 0; i < block_A.size(); ++i) if (block_A_host[i] == ElementA(0)) block_A_host[i] = ElementA(1.0);
|
||||
cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size());
|
||||
|
||||
if (!sparsify_and_compress()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Build the compressed/metadata layouts
|
||||
layout_A = SparseConfig::fill_layoutA(problem_shape);
|
||||
layout_E = SparseConfig::fill_layoutE(problem_shape);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments make_args(Options const& options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_shape,
|
||||
{ block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E },
|
||||
{ { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) },
|
||||
block_C.get(), stride_C, block_D.get(), stride_D }
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
typename GemmRef::Arguments make_args_ref(Options const& options)
|
||||
{
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_shape,
|
||||
{ block_A.get(), stride_A, block_B.get(), stride_B },
|
||||
{ { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) },
|
||||
block_C.get(), stride_C, block_D_ref.get(), stride_D }
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template<class Engine, class Layout>
|
||||
void print_device_tensor(cute::Tensor<Engine, Layout> const& t)
|
||||
{
|
||||
// Assumes size = cosize, i.e. compact tensor
|
||||
std::vector<typename Engine::value_type> data_host(t.size());
|
||||
cutlass::device_memory::copy_to_host(data_host.data(), t.data(), t.size());
|
||||
auto t_host = cute::make_tensor(data_host.data(), t.layout());
|
||||
cute::print_tensor(t_host);
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size());
|
||||
|
||||
#if 0
|
||||
if (!passed) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
CompressorUtility compressor_utility(problem_shape, stride_A);
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
cute::print("A (original): "); print_device_tensor(make_tensor(block_A.get(), make_shape(M, K, L), stride_A));
|
||||
cute::print("A (compressed): "); print_device_tensor(make_tensor(block_A_compressed.get(), make_shape(M, KC, L), stride_A_compressed));
|
||||
cute::print("E (physical): "); print_device_tensor(make_tensor(block_E.get(), make_shape(ME, KE, L), stride_E));
|
||||
cute::print("E (logical): "); print_device_tensor(make_tensor(block_E.get(), upcast<CollectiveMainloop::ElementEMmaSparsity>(layout_E)));
|
||||
cute::print("B: "); print_device_tensor(make_tensor(block_B.get(), make_shape(N, K, L), stride_B));
|
||||
cute::print("C: "); print_device_tensor(make_tensor(block_C.get(), make_shape(M, N, L), stride_C));
|
||||
cute::print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M, N, L), stride_D));
|
||||
cute::print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M, N, L), stride_D));
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
template<typename Gemm>
|
||||
struct Runner
|
||||
{
|
||||
using Arguments = typename Gemm::Arguments;
|
||||
|
||||
Runner(Arguments args): arguments(args) {
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
workspace.reset(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
}
|
||||
|
||||
void run() {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
|
||||
void benchmark(Options const& options) {
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
run();
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
double avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double gflops = options.gflops(avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Avg runtime: " << avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << gflops << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
Gemm gemm;
|
||||
Arguments arguments;
|
||||
cutlass::device_memory::allocation<uint8_t> workspace;
|
||||
};
|
||||
|
||||
/// Execute the example (verification and timing)
|
||||
void run(Options &options) {
|
||||
bool init = initialize(options);
|
||||
if (!init) {
|
||||
std::cout << "Initialization failure" << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
Runner<Gemm> gemm(make_args(options));
|
||||
Runner<GemmRef> gemm_ref(make_args_ref(options));
|
||||
|
||||
gemm.run();
|
||||
gemm_ref.run();
|
||||
|
||||
bool passed = verify(options);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!passed) {
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
std::cout << "Sparse GEMM:" << std::endl;
|
||||
gemm.benchmark(options);
|
||||
|
||||
std::cout << "Dense GEMM:" << std::endl;
|
||||
gemm_ref.benchmark(options);
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.2 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) {
|
||||
std::cerr << "This example requires CUDA 12.2 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED)
|
||||
run(options);
|
||||
#endif
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
examples/62_hopper_sparse_gemm/CMakeLists.txt (new file, 36 lines)
@ -0,0 +1,36 @@

# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# The sparse kernel in this example triggers an internal compiler error (ICE) in gcc 7.5
|
||||
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
|
||||
cutlass_example_add_executable(
|
||||
62_hopper_sparse_gemm
|
||||
62_hopper_sparse_gemm.cu
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,500 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper FP8 GEMM + L2 Weight Prefetch
|
||||
|
||||
This example implements a non-persistent warp-specialized GEMM kernel for the Hopper
|
||||
architecture with programmatic dependent launch (PDL) enabling prefetching weights into
|
||||
L2 cache.
|
||||
|
||||
For more information about dependent launch refer to the CUDA programming guide:
|
||||
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization
|
||||
|
||||
In some cases, PDL can result in a window where a previous kernel is not actively utilizing
|
||||
DRAM, and the next kernel sits idle until the previous finishes. During this window, the next
|
||||
kernel can begin loading a non-dependent operand (i.e. weights in a linear projection are
|
||||
typically static) and cache it in L2.
|
||||
|
||||
The kernel and collective mainloop assume operand `A` corresponds to weights and operand `B`
|
||||
corresponds to activations (so we can have very small batch/token count).
|
||||
After initialization, the prefetch warp starts loading K tiles of `A` into an unused portion
|
||||
of shared memory, and loads up to half of all K tiles that the same CTA would eventually load.
|
||||
The exact number of K tiles loaded is determined by `args.mainloop.prefetch_ratio` \in
|
||||
[0.0, 1.0]. Smaller values result in less prefetching, and larger values result in more.
|
||||
Negative values result in a "best-effort" prefetch, meaning the prefetcher will stop issuing weight
loads as soon as the activation DMA warp starts loading (i.e. as soon as it is signaled that the
previous kernel has flushed its memory).
|
||||
|
||||
The DMA warp responsible for loading `A` will also begin loading K tiles until it fills up
|
||||
the available shared memory.
|
||||
The DMA warp responsible for loading `B` will wait until activations are flushed to global
|
||||
memory by the preceding kernel.
|
||||
|
||||
Another mainloop parameter, `args.mainloop.overlap_ratio` \in [0.0, 1.0] determines how early
|
||||
the next kernel (the one doing the prefetch) is launched. Smaller values result in greater
|
||||
overlap, and larger values result in smaller overlap. Negative values disable PDL completely,
|
||||
meaning there will be no overlap. This will make prefetch ineffective.
|
||||
|
||||
These two runtime parameters should be tuned per problem size and GEMM config combination, and
|
||||
if feasible, per-operation in an entire layer or model.
|
||||
|
||||
NOTE: you must build this target with the following flag to enable Grid Dependency Control
|
||||
instructions (GDC) in CUTLASS:
|
||||
- CUTLASS_ENABLE_GDC_FOR_SM90
|
||||
|
||||
To lock persistence mode, power (350W), clocks (1005MHz) for evaluation (assumes device 0 and H100)
|
||||
|
||||
$ sudo nvidia-smi -pm 1 -i 0
|
||||
|
||||
$ sudo nvidia-smi -i 0 -pl 350
|
||||
|
||||
$ sudo nvidia-smi -i 0 -lgc 1005
|
||||
|
||||
Example:
|
||||
|
||||
$ mkdir build && cd build
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_GDC_FOR_SM90=1
|
||||
|
||||
$ cd examples/63_hopper_gemm_with_weight_prefetch
|
||||
|
||||
$ make
|
||||
|
||||
$ ./63_hopper_gemm_with_weight_prefetch --p=0.5 --o=0.5
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
|
||||
#include "collective/dispatch_policy_extra.hpp"
|
||||
#include "collective/builder.hpp"
|
||||
#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "gemm_with_weight_prefetch_commandline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size
|
||||
// Cluster_N > 1 is not supported yet.
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
double eff_bw;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
double eff_bw = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), eff_bw(eff_bw), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), seed + 2024);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = scalar_alpha.device_data();
|
||||
fusion_args.beta_ptr = scalar_beta.device_data();
|
||||
|
||||
arguments.mainloop.overlap_ratio = options.overlap_ratio;
|
||||
arguments.mainloop.prefetch_ratio = options.prefetch_ratio;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0));
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0));
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0);
|
||||
result.gflops = options.gflops(avg_runtime_s);
|
||||
result.eff_bw = options.effective_bandwidth(avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD));
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
std::cout << " Effective bandwidth: " << result.eff_bw << " GB/s" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt (new file, 36 lines)
@ -0,0 +1,36 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include_directories(
|
||||
.
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
63_hopper_gemm_with_weight_prefetch
|
||||
63_hopper_gemm_with_weight_prefetch.cu
|
||||
)
|
||||
examples/63_hopper_gemm_with_weight_prefetch/README.md (new file, 82 lines)
@ -0,0 +1,82 @@
# GEMM with L2 weight prefetch

A non-persistent warp-specialized GEMM directed at low-latency inference.

The kernel can optionally prefetch a portion of the weights (operand `A`) into L2 cache while the
rest of the warps are waiting on the previous kernel to finish writing and flush its memory.
Typical examples are normalization or reduction kernels immediately followed by a GEMM.

It exposes two runtime parameters:
1. `overlap_ratio`: how early `griddepcontrol.launch_dependent_grids` is issued.
   Default is `0.5`, meaning after approximately half of the K tiles have been loaded by the DMA warps.
2. `prefetch_ratio`: what fraction of K tiles to prefetch.
   Default is `-1.0`, meaning prefetching stops as soon as the other DMA warps are past
   `griddepcontrol`.

Both parameters are plain fields on the mainloop arguments, as sketched below.
It is highly recommended to auto-tune them per GEMM and against some end-to-end
runtime (an entire transformer layer or several, but probably not the entire model).
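
A minimal sketch, mirroring the accompanying example's `args_from_options()`; the `Gemm` type,
`options`, tensors, strides, the constructed `gemm` object, and `workspace` are assumed to be set
up as in that example:

```cxx
// Sketch following the example; only the two mainloop fields are specific to this kernel.
typename Gemm::Arguments arguments{
  cutlass::gemm::GemmUniversalMode::kGemm,
  {options.m, options.n, options.k, options.l},
  {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
  {{}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D}
};

// Launch the dependent grid after ~50% of K tiles are loaded, and prefetch up to ~70%
// of the K tiles of A that this CTA would load anyway.
arguments.mainloop.overlap_ratio  = 0.5f;
arguments.mainloop.prefetch_ratio = 0.7f;

// A negative overlap_ratio disables PDL entirely; a negative prefetch_ratio means best-effort
// prefetch. The kernel is launched with PDL whenever overlap is enabled:
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ arguments.mainloop.overlap_ratio >= 0));
```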

TMA loads use non-default cache hints: `A` (the weights) is loaded with `EvictFirst`, and `B`
(the activations) is loaded with `EvictLast`.
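
Inside a CuTe mainloop this roughly corresponds to passing a cache hint to the TMA copy. The
following is a rough, non-authoritative sketch: it assumes a `.with()` overload that accepts a
`cute::TMA::CacheHintSm90` hint, and all tensors, barriers, and masks are placeholders rather
than names from this example:

```cxx
// Sketch only: tma_load_a/b, tma_barrier, mcast_mask_a/b, tAgA/tAsA, tBgB/tBsB,
// k_tile and write_stage are assumed to be set up as in a typical SM90 TMA mainloop.
cute::copy(tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST),
           tAgA(_,k_tile), tAsA(_,write_stage));  // weights: evict first
cute::copy(tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST),
           tBgB(_,k_tile), tBsB(_,write_stage));  // activations: evict last
```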

## Getting started
To use this kernel in your own target, add this directory to your includes, and include the
following headers from this example:

```cxx
#include "collective/dispatch_policy_extra.hpp"
#include "collective/builder.hpp"
#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp"
```

Then use either one of the new kernel schedules:

```cxx
// Without separate warps for A and B
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetch;

// With separate warps for A and B
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
```

The kernel with separate warps for A and B
(`KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA`)
is expected to be the more performant of the two, especially since it allows the kernel to load
weights into shared memory ahead of the `griddepcontrol`.

As for other GEMM parameters: thread block clusters larger than 1 CTA are not yet supported, the
kernel layer implementation is warp-specialized and uses TMA, and other kernel layers or
collectives would require reimplementation. A sketch of wiring a schedule into the collective
builder follows.
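
For reference, here is a sketch of how a schedule plugs into the SM90 collective builder. It
mirrors the configuration used in the accompanying example; the element types, alignments, tile
and cluster shapes, and `CollectiveEpilogue` are placeholders taken from that example:

```cxx
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16,      // A (weights) and its alignment
    cutlass::float_e5m2_t, cutlass::layout::ColumnMajor, 16,   // B (activations) and its alignment
    float,                                                     // accumulator
    cute::Shape<cute::_64, cute::_64, cute::_128>,             // CTA tile shape
    cute::Shape<cute::_1, cute::_1, cute::_1>,                 // cluster shape (only 1x1x1 supported)
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule                                             // one of the two schedules above
  >::CollectiveOp;
```

The resulting `CollectiveMainloop` is then composed into `cutlass::gemm::kernel::GemmUniversal`
exactly as in the example.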

## Example

Using the example is mostly straightforward.
Just build, and run with your choice of `MNK`:

```bash
./63_hopper_gemm_with_weight_prefetch --m=8192 --n=1 --k=8192
```

You can also disable the overlap, or try different overlap and prefetch ratios, and see the
difference:

```bash
echo "Without overlap and prefetch"
./63_hopper_gemm_with_weight_prefetch --o=-1.0 --p=-1.0

echo "Overlap ratio of 0.5, best effort prefetch"
./63_hopper_gemm_with_weight_prefetch --o=0.5 --p=-1.0

echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
./63_hopper_gemm_with_weight_prefetch --o=0.8 --p=0.7
```

However, note that the example still runs a single GEMM; most of the performance improvement
is expected in end-to-end applications.

## Limitations
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
  When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
  memory barrier before issuing every single TMA load, and in many cases this will slow down
  prefetching to the point of being almost ineffective.
@ -0,0 +1,215 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "dispatch_policy_extra.hpp"
|
||||
#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class KernelScheduleType
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
|
||||
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
|
||||
|
||||
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class KernelScheduleType
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>>
|
||||
> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(detail::is_input_fp8<ElementA, ElementB>(),
|
||||
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();
|
||||
|
||||
using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
|
||||
ElementA, ElementB, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,61 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
// Standard non-persistent kernel with a single producer warp, and one prefetch warp.
|
||||
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
|
||||
// while the producer warp is waiting on griddepcontrol.
|
||||
// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and
|
||||
// according to prefetch ratio.
|
||||
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { };
|
||||
|
||||
// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp.
|
||||
// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A`
|
||||
// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not
|
||||
// wait on griddepcontrol and loads immediately.
|
||||
struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { };
|
||||
|
||||
template<
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch {
|
||||
constexpr static int Stages = Stages_;
|
||||
using ClusterShape = ClusterShape_;
|
||||
using ArchTag = arch::Sm90;
|
||||
using Schedule = KernelSchedule;
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@ -0,0 +1,867 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cutlass/arch/grid_dependency_control.h"
|
||||
|
||||
#include "dispatch_policy_extra.hpp"
|
||||
|
||||
#include "../pipeline/prefetch_pipeline_sm90.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1");
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
|
||||
static constexpr int PrefetchStages = 4;
|
||||
static constexpr int PrefetchInitialStages = 1;
|
||||
// This determines how much shmem we set aside for prefetch.
// We don't reuse anything loaded by the prefetcher, so we can keep
// loading into the same place: writes will conflict, but that costs far
// less than the shared memory it frees up for the mainloop stages.
|
||||
static constexpr int PrefetchStagesActual = 1;
|
||||
using PrefetcherPipeline = cutlass::PrefetchPipeline<PrefetchStages>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages);
|
||||
static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages);
|
||||
|
||||
using PrefetchSmemLayoutA = decltype(make_layout(make_shape(
|
||||
cute::Int<size<0>(SmemLayoutA{})>{},
|
||||
cute::Int<size<1>(SmemLayoutA{})>{},
|
||||
cute::Int<PrefetchStagesActual>{})));
|
||||
|
||||
static constexpr auto prefetch_smem_size = cute::cosize_v<PrefetchSmemLayoutA>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
||||
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
|
||||
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
||||
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
||||
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
|
||||
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
|
||||
// Defined outside the class where it's used, to work around MSVC issues
|
||||
using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage<PrefetchStages>;
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128, _0> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, prefetch_smem_size> smem_prefetch;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
PrefetcherPipelineStorage prefetcher_pipeline;
|
||||
};
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
float overlap_ratio = 0.5;
|
||||
float prefetch_ratio = -1.0;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
float overlap_ratio = 0.5;
|
||||
float prefetch_ratio = -1.0;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
||||
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
tensor_a,
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
tensor_b,
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
||||
|
||||
return {
|
||||
tma_load_a,
|
||||
tma_load_b,
|
||||
transaction_bytes,
|
||||
transaction_bytes_mk,
|
||||
transaction_bytes_nk,
|
||||
args.overlap_ratio,
|
||||
args.prefetch_ratio
|
||||
};
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (args.overlap_ratio > 1.0) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (args.prefetch_ratio > 1.0) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytesMK =
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
static constexpr uint32_t TmaTransactionBytesNK =
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
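// Worked example (hypothetical tile/element sizes): for a 128x128 (BLK_M x BLK_K)
// SmemLayoutA stage of an 8-bit ElementA, TmaTransactionBytesMK =
// bits_to_bytes(128 * 128 * 8) = 16384 bytes expected per pipeline stage.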
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The contract between the collective and the kernel layer
/// is that the returned tuple contains at least two elements, the first two being:
///   gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
///   gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl);
|
||||
}
|
||||
|
||||
template <
|
||||
class TensorA, class TensorB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorA const& gA_mkl,
|
||||
TensorB const& gB_nkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();

if (lane_predicate) {
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
float overlap_ratio = mainloop_params.overlap_ratio;
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);

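// Illustrative numbers: with overlap_ratio = 0.5 and 64 K tiles, the threshold is
// int((64 - 1) * 0.5) = 31, so launch_dependent_grids() is called roughly halfway
// through the 64 tile loads, intended to overlap the tail of this kernel with the
// next grid's prologue.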
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_a
|
||||
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
// Applies the mapping from cta_tma_b
|
||||
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||

// Wait on griddepcontrol before issuing loads of B: unlike A (the weights, which are
// assumed static), B may be produced by the preceding grid and must be flushed first.
cutlass::arch::wait_on_dependent_grids();

// Signal prefetcher to stop
prefetcher_pipeline.producer_arrive();
|
||||
|
||||
bool launch_dep_grids = false;
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
launch_dep_grids = true;
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class TensorA,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load_MK(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorA const& gA_mkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
bool disable_gdc = mainloop_params.overlap_ratio < 0.0;
|
||||
float overlap_ratio = mainloop_params.overlap_ratio;
|
||||
int launch_dep_grids_threshold = static_cast<int>(static_cast<float>(k_tile_count - 1) * overlap_ratio);
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_a
|
||||
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Don't wait on dependent grids when loading `A`, because
|
||||
// we assume `A` (weights) are static.
|
||||
|
||||
bool launch_dep_grids = false;
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
|
||||
launch_dep_grids = true;
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
if (!disable_gdc && !launch_dep_grids) {
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class TensorB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load_NK(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorB const& gB_nkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_b
|
||||
Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the prefetched kernel does not touch
|
||||
// unflushed global memory prior to this instruction
|
||||
cutlass::arch::wait_on_dependent_grids();
|
||||
|
||||
// Signal prefetcher to stop
|
||||
prefetcher_pipeline.producer_arrive();
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
class TensorA,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
prefetch_MK(
|
||||
Params const& mainloop_params,
|
||||
PrefetcherPipeline prefetcher_pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
TensorA const& gA_mkl,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0;
float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio;
int prefetch_iters = static_cast<int>(static_cast<float>(k_tile_count) * 0.5 * prefetch_ratio);
prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages);
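// Reading of the two lines above (illustrative, not normative): each prefetch iteration
// below advances the tile iterator by two, so prefetch_iters ~= 0.5 * prefetch_ratio *
// k_tile_count strides across roughly `prefetch_ratio` of the K extent; the count is
// then rounded up to a whole multiple of PrefetchStages (and clamped to k_tile_count).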
Tensor sA = make_tensor(
|
||||
make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
|
||||
// Applies the mapping from cta_tma_a
|
||||
Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t prefetcher_stage = 0;
|
||||
uint32_t prefetcher_phase = 0;
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) {
|
||||
|
||||
if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) {
|
||||
break;
|
||||
}
|
||||
|
||||
prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages);
|
||||
using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType;
|
||||
BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage);
|
||||
|
||||
int write_stage = 0;
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
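// Advance by two K tiles per prefetch iteration -- the prefetcher touches only every
// other tile, which appears deliberate and matches the 0.5 factor in prefetch_iters above.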
++k_tile_iter;
++k_tile_iter;
|
||||
prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase);
|
||||
}
|
||||
prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <
|
||||
class FrgTensorC
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_read,
|
||||
FrgTensorC& accum,
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
warpgroup_fence_operand(accum);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum);
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_fence_operand(accum);
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accum);
|
||||
|
||||
// UNLOCK smem_pipe_release, done _computing_ on it
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum);
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void
|
||||
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,117 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
float overlap_ratio = 0.5f, prefetch_ratio = 0.5f;
|
||||
int iterations = 1000;
|
||||
int n = 64, m = 1280, k = 8192, l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f);
|
||||
cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "63_hopper_gemm_with_weight_prefetch\n\n"
|
||||
<< " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. \n"
|
||||
<< " For more details please refer to the source file.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --p=<f32> Prefetch ratio\n"
|
||||
<< " --o=<f32> Overlap ratio\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "63_hopper_gemm_with_weight_prefetch" <<
|
||||
" --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
/// Compute effective bandwidth in GB/sec
|
||||
double effective_bandwidth(
|
||||
double runtime_s,
|
||||
size_t bytes_a,
|
||||
size_t bytes_b,
|
||||
size_t bytes_c,
|
||||
size_t bytes_d
|
||||
) const
|
||||
{
|
||||
static double const kBytesPerGiB = double(1ull << 30);
|
||||
|
||||
double bytes_in =
|
||||
(double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A
|
||||
(double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B
|
||||
(beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C
|
||||
double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D
|
||||
|
||||
double gb_total = (bytes_in + bytes_out) / kBytesPerGiB;
|
||||
return gb_total / runtime_s;
|
||||
}
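// Example with the defaults m=1280, n=64, k=8192, l=1 and hypothetical 1-byte A/B,
// 2-byte D, beta=0: bytes_in ~= (1280 + 64) * 8192 bytes ~= 10.5 MiB and
// bytes_out = 1280 * 64 * 2 bytes = 160 KiB.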
|
||||
};
|
||||
@ -0,0 +1,561 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "../collective/dispatch_policy_extra.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GEMM + Prefetch for the A tensor + (optional) split DMA warps
|
||||
template <
|
||||
class ProblemShape_,
|
||||
class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_,
|
||||
class TileScheduler_
|
||||
>
|
||||
class GemmUniversal<
|
||||
ProblemShape_,
|
||||
CollectiveMainloop_,
|
||||
CollectiveEpilogue_,
|
||||
TileScheduler_,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA> ||
|
||||
cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>
|
||||
>
|
||||
>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
|
||||
|
||||
static constexpr bool SplitWarps = cute::is_same_v<typename CollectiveMainloop_::DispatchPolicy::Schedule, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>;
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
|
||||
"TMA warp-specialized kernel does not support specializing the tile scheduler.");
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler = typename detail::TileSchedulerSelector<
|
||||
TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage {
|
||||
// Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union
|
||||
union TensorStorage {
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16, _1> {
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) PrefetcherPipelineStorage prefetcher;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 1;
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
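// E.g., assuming a single warp-group-sized TiledMma (128 threads), this is
// 128 + 1 * 128 = 256 threads: one consumer (MMA) warp group plus one producer
// (DMA/prefetch) warp group.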
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static
|
||||
Params
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
|
||||
// swap M/N
|
||||
get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
}
|
||||
return {
|
||||
args.mode,
|
||||
problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
|
||||
};
|
||||
}
|
||||
|
||||
static bool
|
||||
can_implement(Arguments const& args) {
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
||||
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static
|
||||
size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static
|
||||
cutlass::Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
auto cluster_shape = ClusterShape{};
|
||||
auto tile_shape = TileShape{};
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
return TileScheduler::get_tiled_cta_shape_mnl(
|
||||
problem_shape_MNKL, tile_shape, cluster_shape);
|
||||
}
|
||||
|
||||
static dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
operator()(Params const& params, char* smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
#if defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
# define ENABLE_SM90_KERNEL_LEVEL 1
|
||||
#endif
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if ! defined(ENABLE_SM90_KERNEL_LEVEL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
enum class WarpGroupRole {
Producer = 0,
Consumer = 1,
};
// Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK.
// Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused.
// Both modes use Warp1 (ProducerWarpRole::PrefetchMK) to prefetch.
enum class ProducerWarpRole {
Warp0 = 0,
PrefetchMK = 1,
Warp2 = 2,
UnusedWarp = 3
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
if (warp_group_role == WarpGroupRole::Producer && (
|
||||
producer_warp_role == ProducerWarpRole::Warp0 ||
|
||||
producer_warp_role == ProducerWarpRole::Warp2)) {
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer) {
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
bool should_prefetch = params.mainloop.prefetch_ratio > 0;
|
||||
using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline;
|
||||
typename PrefetcherPipeline::Params prefetcher_pipeline_params;
|
||||
prefetcher_pipeline_params.num_prefetchers = 1;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
||||
prefetcher_pipeline_params.should_prefetch = should_prefetch;
|
||||
prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk;
|
||||
}
|
||||
PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params);
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
|
||||
epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
|
||||
}
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = [&] () {
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
// Non-prefetcher warps arrive and wait,
|
||||
// Prefetcher warp can go ahead without waiting.
|
||||
cute::cluster_arrive_relaxed();
|
||||
if (warp_group_role != WarpGroupRole::Producer ||
|
||||
producer_warp_role != ProducerWarpRole::PrefetchMK) {
|
||||
cute::cluster_wait();
|
||||
}
|
||||
return [] () {};
|
||||
}
|
||||
else {
|
||||
// __syncthreads() but only for non prefetcher warps
|
||||
if (should_prefetch) {
|
||||
|
||||
// Use a named barrier to let the prefetcher warp start loading into the L2
|
||||
// without waiting to sync with all other warps.
|
||||
// All other warps need to sync because the mainloop pipeline init
|
||||
// should be visible to all of them.
|
||||
// Prefetcher has its own barriers, and the only warps it would need to sync
|
||||
// with would be the DMA warps.
|
||||
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
|
||||
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z,
|
||||
/*reserved_named_barriers_*/ 14);
|
||||
// Prefetcher warp doesn't arrive on this barrier.
|
||||
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
|
||||
/*reserved_named_barriers_*/ 15);
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
||||
__syncwarp();
|
||||
prefetcher_arrive_barrier.arrive();
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Producer) {
|
||||
prefetcher_arrive_barrier.arrive_and_wait();
|
||||
cluster_arrive_barrier.arrive_and_wait();
|
||||
}
|
||||
else {
|
||||
prefetcher_arrive_barrier.arrive();
|
||||
cluster_arrive_barrier.arrive_and_wait();
|
||||
}
|
||||
} else {
|
||||
__syncthreads();
|
||||
}
|
||||
return [] () {};
|
||||
}
|
||||
} ();
|
||||
|
||||
// Preconditions
|
||||
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
TiledMma tiled_mma;
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
// Compute m_coord, n_coord, and l_coord with their post-tiled shapes
|
||||
auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Get pipeline iterators and increments from tensor shapes
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
if (producer_warp_role == ProducerWarpRole::Warp0) {
|
||||
if constexpr(SplitWarps) {
|
||||
collective_mainloop.load_NK(
|
||||
params.mainloop,
|
||||
mainloop_pipeline,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gB_nkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
}
|
||||
else {
|
||||
collective_mainloop.load(
|
||||
params.mainloop,
|
||||
mainloop_pipeline,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gA_mkl, gB_nkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
}
|
||||
// Update starting mainloop pipeline state for the pipeline drain
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
// Make sure mainloop consumer has been waited upon before issuing epilogue load
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
|
||||
if (collective_epilogue.is_producer_load_needed()) {
|
||||
// Ensure warp is converged before issuing epilogue loads
|
||||
__syncwarp();
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_producer_state,
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord,
|
||||
tiled_mma,
|
||||
lane_idx,
|
||||
shared_storage.tensors.epilogue
|
||||
);
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
}
|
||||
}
|
||||
else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) {
|
||||
collective_mainloop.load_MK(
|
||||
params.mainloop,
|
||||
mainloop_pipeline,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gA_mkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
// Update starting mainloop pipeline state for the pipeline drain
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
// Make sure mainloop consumer has been waited upon before issuing epilogue load
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
} else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) {
|
||||
collective_mainloop.prefetch_MK(
|
||||
params.mainloop,
|
||||
prefetcher_pipeline,
|
||||
mainloop_pipe_producer_state,
|
||||
gA_mkl,
|
||||
blk_coord,
|
||||
k_tile_iter, k_tile_count,
|
||||
lane_idx,
|
||||
block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop
|
||||
);
|
||||
}
|
||||
}
|
||||
else if (warp_group_role == WarpGroupRole::Consumer) {
|
||||
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
accumulators,
|
||||
k_tile_count,
|
||||
warp_group_thread_idx,
|
||||
shared_storage.tensors.mainloop,
|
||||
params.mainloop
|
||||
);
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
collective_mainloop.mma_tail(
|
||||
mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
k_tile_count
|
||||
);
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
|
||||
collective_epilogue.store(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline,
|
||||
epi_store_pipe_producer_state,
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord,
|
||||
accumulators,
|
||||
tiled_mma,
|
||||
warp_group_thread_idx,
|
||||
shared_storage.tensors.epilogue
|
||||
);
|
||||
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline,
|
||||
epi_store_pipe_producer_state_next
|
||||
);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@ -0,0 +1,161 @@
/***************************************************************************************************
 * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/barrier.h"
#include "cute/container/array.hpp"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

namespace detail {

// MSVC work-around
template <int Stages>
struct PrefetcherPipelineSharedStorage {
  using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier;
  using Barrier = cutlass::arch::ClusterBarrier;

  TransactionBarrier tma_barrier[Stages];
  Barrier producer_ready_barrier;
};

} // end namespace detail

using namespace cute;

// The prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction
// barrier providing control over the number of concurrent outstanding TMA loads.
// There is also an additional cluster barrier which is only used when `prefetch_ratio` is unset.
// `prefetch_ratio` determines how many K tiles get loaded; when it is unset, the prefetcher checks
// whether the DMA warps are done waiting on griddepcontrol and, if so, stops issuing more TMA loads.
template <int Stages_>
class PrefetchPipeline {
public:
  static constexpr uint32_t Stages = Stages_;
  using SharedStorage = detail::PrefetcherPipelineSharedStorage<Stages>;

  using TransactionBarrier = typename SharedStorage::TransactionBarrier;
  using Barrier = typename SharedStorage::Barrier;
  using PrefetcherBarrierType = typename TransactionBarrier::ValueType;

  struct Params {
    uint32_t transaction_bytes = 0;
    uint32_t num_prefetchers = 1;
    bool should_prefetch = false;
  };

  // Constructor
  CUTLASS_DEVICE
  PrefetchPipeline(SharedStorage& storage, Params params)
      : params_(params)
      , tma_barrier_ptr_(&storage.tma_barrier[0])
      , producer_ready_barrier_ptr_(&storage.producer_ready_barrier) {

    int lane_predicate = cute::elect_one_sync();
    if (params.should_prefetch && lane_predicate) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Stages; ++i) {
        tma_barrier_ptr_[i].init(params.num_prefetchers);
      }
      producer_ready_barrier_ptr_[0].init(1);
    }
  }

  CUTLASS_DEVICE
  void producer_arrive() {
    if (params_.should_prefetch) {
      producer_ready_barrier_ptr_[0].arrive();
    }
  }

  CUTLASS_DEVICE
  bool have_producers_arrived() {
    if (params_.should_prefetch) {
      uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0);
      auto barrier_status = static_cast<BarrierStatus>(barrier_status_);
      if (barrier_status == BarrierStatus::WaitDone) {
        return true; // exit prefetcher loop
      }
      return false;
    }
    return true;
  }

  CUTLASS_DEVICE
  void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) {
    if (params_.should_prefetch) {
      if (should_wait) {
        tma_barrier_ptr_[stage].wait(phase ^ 1);
      }
      tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes);
    }
  }

  CUTLASS_DEVICE
  void advance_prefetcher_state(uint32_t& stage, uint32_t& phase) {
    if (params_.should_prefetch) {
      stage++;
      if (stage == Stages) {
        stage = 0;
        phase ^= 1;
      }
    }
  }

  CUTLASS_DEVICE
  void prefetcher_tail(uint32_t stage, uint32_t phase) {
    if (params_.should_prefetch) {
      // Wait on any already-issued loads
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < stage; ++i) {
        tma_barrier_ptr_[i].wait(phase);
      }
    }
  }

  CUTLASS_DEVICE
  PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) {
    return reinterpret_cast<PrefetcherBarrierType*>(&tma_barrier_ptr_[stage]);
  }

private:
  TransactionBarrier* tma_barrier_ptr_ = nullptr;
  Barrier* producer_ready_barrier_ptr_ = nullptr;
  Params params_;

};

} // end namespace cutlass
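
For orientation, here is a minimal usage sketch, not part of this commit, of how a prefetcher warp might drive the PrefetchPipeline above. It assumes `cutlass/cutlass.h` and the pipeline header above are already included; `issue_prefetch_tma_load` and `num_k_tiles` are hypothetical placeholders for the TMA prefetch load and the K-tile count that the real kernel supplies, and the actual wiring in the kernel may differ.

// Sketch only: one prefetcher warp driving PrefetchPipeline<Stages>.
template <int Stages>
CUTLASS_DEVICE
void prefetcher_loop_sketch(cutlass::PrefetchPipeline<Stages>& prefetcher_pipeline, int num_k_tiles) {
  uint32_t stage = 0;
  uint32_t phase = 0;
  for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) {
    // Stop issuing prefetch loads once the DMA warps signal they are past griddepcontrol.
    if (prefetcher_pipeline.have_producers_arrived()) {
      break;
    }
    // Only wait on the stage barrier once every stage has been used at least once.
    bool should_wait = (k_tile >= Stages);
    prefetcher_pipeline.prefetcher_acquire(stage, phase, should_wait);
    // issue_prefetch_tma_load(prefetcher_pipeline.prefetcher_get_barrier(stage), k_tile);  // hypothetical TMA load
    prefetcher_pipeline.advance_prefetcher_state(stage, phase);
  }
  // Drain barriers for loads that are still in flight before exiting.
  prefetcher_pipeline.prefetcher_tail(stage, phase);
}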
@ -140,6 +140,9 @@ foreach(EXAMPLE
  57_hopper_grouped_gemm
  58_ada_fp8_gemm
  59_ampere_gather_scatter_conv
  61_hopper_gemm_with_topk_and_softmax
  62_hopper_sparse_gemm
  63_hopper_gemm_with_weight_prefetch
  )

  add_subdirectory(${EXAMPLE})

@ -186,8 +186,8 @@ int main(int argc, char** argv)
    return -1;
  }
  // Equivalent check to the above
  if (not weakly_compatible(block_shape, tensor_shape)) {
    std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl;
  if (not evenly_divides(tensor_shape, block_shape)) {
    std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl;
    return -1;
  }
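
As a side note, a minimal sketch, not from this change, of what the `evenly_divides` check above evaluates, using hypothetical shapes: it only passes when every mode of `tensor_shape` is a whole multiple of the corresponding mode of `block_shape`.

#include <cute/tensor.hpp>

int check_block_shape_sketch() {
  auto tensor_shape = cute::make_shape(256, 512);                          // hypothetical (M, N) of the full tensor
  auto block_shape  = cute::make_shape(cute::Int<128>{}, cute::Int<64>{}); // hypothetical (BLK_M, BLK_N) tile
  if (not cute::evenly_divides(tensor_shape, block_shape)) {
    return -1;  // not taken here: 256 is a multiple of 128 and 512 is a multiple of 64
  }
  return 0;
}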