Improve sm90 mixed dtype kernel (#1883)

Sergey Klevtsov
2024-10-17 17:06:38 -07:00
committed by GitHub
parent 755194a7bd
commit 08101d9d0c
11 changed files with 994 additions and 80 deletions

View File

@ -0,0 +1,657 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
This example shows how to perform INT4 x BF16 GEMM and scale up the INT4 weight during dequantization.
The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue,
as illustrated in this example.
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
As an additional optimization, we can reorder the narrow data type tensor such that elements read into the register file by the same thread are contiguous in global and shared memory.
This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when the MMA is performed in BF16, each thread reads
4 groups of 2 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-a for the thread-value layout).
If the narrow type is INT4 and the tensor is major in the K dim, only 8 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
If we reorder the data offline so that all 8 elements read by a thread are contiguous in memory, a single 32-bit load suffices. This reordering is often feasible when the quantized
tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro
OPTIMIZE_WEIGHT_LAYOUT is set to 1.
The K dimension of the scales is expected to be scale_k = ceil_div(problem_k, group_size); for example, k = 4096 with group size 128 gives scale_k = 32.
Scales are always expected to be MN major: the fastest changing dimension must be M if A is scaled, or N if B is scaled.
If A is being scaled, the scales must have shape [M, scale_k]; if B is scaled, they must have shape [N, scale_k].
The implementation only supports "group-wise" scales. However, per-column scaling can be obtained by setting the group size
equal to the GEMM problem K.
Limitations:
1) Only INT4 x { FP16, BF16 } is supported. The scale type must match the MMA type. Scale-with-zero-point mode is not supported.
2) The INT4 weights have additional encoding requirements.
3) The scales must be MN major: if A is scaled, the scales must be column major; if B is scaled, they must be row major.
4) The scales and zero-points must have the same layout and group size.
5) The group size must be greater than or equal to the tile shape K.
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
operands so that the narrow type passes through the register file, and TMA epilogues do not currently support this implicit swap + transpose.
We plan to address this limitation in the future; for now, this example works around it by explicitly swapping and transposing the operands.
Optimization suggestions:
1) Use a small tile size, since register pressure for this GEMM (and RS GEMMs in general) is high.
Examples:
Runs the mixed input batched gemm (with batch size 2), converting B to the type of A:
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=2048 --n=2048 --k=2048 --l=2
Runs the mixed input gemm and applies a scaling factor to B before the mma. A single vector of scales is applied to the entire
matrix (the group size equals the gemm K dimension):
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=4096 --n=5120 --k=8192 --g=8192
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "helper.h"
#include "unfused_weight_dequantize.hpp"
#include "packed_scale.hpp"
#include "reorder_utils.hpp"
using namespace cute;
#define OPTIMIZE_WEIGHT_LAYOUT 1
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
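// i.e. 128 bytes of MmaType elements along K: for bf16 (16 bits), TileShapeK = 128 * 8 / 16 = 64.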
// A matrix configuration
using ElementA = MmaType; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = QuantType; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// This example manually swaps and transposes, so keep transpose of input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
#if OPTIMIZE_WEIGHT_LAYOUT
// Define the CuTe layout for the reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
#endif
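// Note: the reordered layout keeps the same (N, K, L) shape as the canonical layout, so kernel
// indexing is unchanged; only the coordinate-to-offset mapping differs. The atom is repeated
// over the runtime shape with tile_to_shape(LayoutAtomQuant{}, shape_B) in initialize() below.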
using ElementScale = MmaType;
using ElementZero = ElementScale; // zero-points are used only by the verification path
using LayoutScale = cutlass::layout::RowMajor;
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// D matrix configuration
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
EpilogueTileType,
ElementAccumulator, ElementAccumulator,
// Transpose the layouts of C and D here since we use explicit swap + transpose.
// A void type for C tells the builder to allocate no smem for the C matrix;
// this can be enabled when beta == 0 by changing ElementC to void below.
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
>::CollectiveOp;
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
// The scale information must be paired with the operand that will be scaled. In this example, B is scaled, so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
#if OPTIMIZE_WEIGHT_LAYOUT
cute::tuple<ElementB, ElementScale>, LayoutB_Reordered, AlignmentB,
#else
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose, AlignmentB,
#endif
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
>,
KernelSchedule
>::CollectiveOp;
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopScaleOnly,
CollectiveEpilogue
>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideC = typename GemmKernelScaleOnly::StrideC;
using StrideD = typename GemmKernelScaleOnly::StrideD;
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
//
// Data members
//
/// Initialization
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideC_ref stride_C_ref;
StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;
#if OPTIMIZE_WEIGHT_LAYOUT
LayoutB_Reordered layout_B_reordered;
#endif
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
StrideS_ref stride_S_ref;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementA> block_B_dq;
cutlass::DeviceAllocation<ElementScale> block_scale;
cutlass::DeviceAllocation<ElementZero> block_zero;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 10;
int m = 5120, n = 4096, k = 4096;
int g = 128;
int l = 1;
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("g", g);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "55_hopper_warp_specialized_gemm\n\n"
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
<< " --g=<int> The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t error = cudaSuccess;
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
double scope_max, scope_min;
int bits = cutlass::sizeof_bits<Element>::value;
if (bits == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <typename Element>
bool initialize_quant_tensor(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <class Element>
bool initialize_scale(
cutlass::DeviceAllocation<Element>& block,
Options const& options) {
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
float const max_dequant_val = 4.f;
float const min_dequant_val = 0.5f;
float scope_max(max_dequant_val / elt_max_f);
float scope_min(min_dequant_val / elt_max_f);
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
return true;
}
template <class Element>
bool initialize_zero(
cutlass::DeviceAllocation<Element>& block,
Options const& options) {
std::vector<Element> stage(block.size(), Element(0.0f));
block.copy_from_host(stage.data());
return true;
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options const& options) {
auto shape_B = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
// Reverse stride here due to swap and transpose
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
auto layout_B = make_layout(shape_B, stride_B);
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
block_A.reset(a_coord.product());
block_B.reset(b_coord.product());
block_B_dq.reset(b_coord.product());
block_C.reset(c_coord.product());
block_D.reset(c_coord.product());
block_ref_D.reset(c_coord.product());
block_scale.reset(scale_k * options.l * options.n);
block_zero.reset(scale_k * options.l * options.n);
initialize_tensor(block_A, seed + 2022);
initialize_quant_tensor(block_B, seed + 2021);
initialize_tensor(block_C, seed + 2020);
initialize_scale(block_scale, options);
initialize_zero(block_zero, options);
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
#if OPTIMIZE_WEIGHT_LAYOUT
// Repeat the reorder layout atom to tile the whole tensor shape
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
print("Quantized tensor layout: ");
print(layout_B_reordered);
print("\n");
#endif
}
/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Args>
Args args_from_options(Options const& options)
{
// Swap the A and B tensors, as well as problem shapes here.
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
#if OPTIMIZE_WEIGHT_LAYOUT
{block_B.get(), layout_B_reordered, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
#else
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
#endif
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
bool verify(Options const& options) {
//
// Compute reference output
//
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
MmaType, LayoutA, AlignmentA,
MmaType, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
cutlass::epilogue::NoSmemWarpSpecialized
>::CollectiveOp;
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloopRef,
CollectiveEpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
};
// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
// compare_reference
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
initialize(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options<typename Gemm::Arguments>(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
CUTLASS_CHECK(gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(gemm.run());
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
}
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA 12 or newer.\n";
// Returning zero so this test passes on older toolkits, where its actions are a no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (props.major < 9) {
std::cerr
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
<< "later (compute capability 90 or greater).\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
if (options.g == options.k) {
std::cout << "Running in per-column scale mode." << std::endl;
} else {
std::cout << "Running in group scale mode." << std::endl;
}
run<GemmScaleOnly>(options);
#endif
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -47,6 +47,14 @@
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory.
This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in FP8 data type, each thread reads
4 groups of 4 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n32-a for thread-value layout).
If the narrow type is INT4 and tensor is major in K dim, only 16 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized
tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro
OPTIMIZE_WEIGHT_LAYOUT is set to 1.
The K dimension of the scales is expected to be scale_k = ceil_div(problem_k, group_size).
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
@ -104,9 +112,12 @@
#include "helper.h"
#include "unfused_weight_dequantize.hpp"
#include "packed_scale.hpp"
#include "reorder_utils.hpp"
using namespace cute;
#define OPTIMIZE_WEIGHT_LAYOUT 1
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -130,6 +141,17 @@ constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // M
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
#if OPTIMIZE_WEIGHT_LAYOUT
// Define the CuTe layout for the reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
// It specifies the reordering within a single warp's fragment
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
#endif
using ElementScale = MmaType;
using ElementZero = ElementScale; // only for verify
using LayoutScale = cutlass::layout::RowMajor;
@ -172,7 +194,11 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
// The scale information must be paired with the operand that will be scaled. In this example, B is scaled, so we make a tuple of B's information and the scale information.
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8> >, LayoutB_Transpose, AlignmentB,
#if OPTIMIZE_WEIGHT_LAYOUT
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Reordered, AlignmentB,
#else
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Transpose, AlignmentB,
#endif
ElementA, LayoutA_Transpose, AlignmentA,
ElementAccumulator,
TileShape, ClusterShape,
@ -190,8 +216,6 @@ using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using StrideC = typename GemmKernelScaleOnly::StrideC;
using StrideD = typename GemmKernelScaleOnly::StrideD;
@ -211,6 +235,10 @@ StrideD stride_D;
StrideD_ref stride_D_ref;
uint64_t seed;
#if OPTIMIZE_WEIGHT_LAYOUT
LayoutB_Reordered layout_B_reordered;
#endif
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
StrideS stride_S;
@ -399,7 +427,7 @@ bool unify_quant_encoding(
d = out;
}
cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2);
cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack);
return true;
}
@ -461,10 +489,10 @@ bool initialize_zero(
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options const& options) {
auto shape_b = cute::make_shape(options.n, options.k, options.l);
auto shape_B = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
// Reverse stride here due to swap and transpose
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
@ -472,6 +500,8 @@ void initialize(Options const& options) {
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
auto layout_B = make_layout(shape_B, stride_B);
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
@ -496,14 +526,22 @@ void initialize(Options const& options) {
initialize_packed_scale(block_scale, block_scale_packed);
initialize_zero(block_zero, options);
auto layout_B = make_layout(shape_b, stride_B);
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
#if OPTIMIZE_WEIGHT_LAYOUT
// Repeat the reorder layout atom to tile the whole tensor shape
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
print("Quantized tensor layout: ");
print(layout_B_reordered);
print("\n");
#endif
}
/// Populates a Gemm::Arguments structure from the given commandline options
@ -515,7 +553,11 @@ Args args_from_options(Options const& options)
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
#if OPTIMIZE_WEIGHT_LAYOUT
{block_B_modified.get(), layout_B_reordered, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
#else
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
#endif
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
@ -581,6 +623,7 @@ bool verify(Options const& options) {
ElementD const epsilon(1e-2f);
ElementD const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
}

View File

@ -68,3 +68,14 @@ cutlass_example_add_executable(
TEST_SCALE_RESIDUE
# TEST_ALPHA_BETA
)
cutlass_example_add_executable(
55_hopper_int4_bf16_gemm
55_hopper_int4_bf16_gemm.cu
TEST_COMMAND_OPTIONS
TEST_DIRECT_BATCHED
TEST_SCALE_PERCOL
TEST_SCALE_GROUP
TEST_SCALE_RESIDUE
# TEST_ALPHA_BETA
)

View File

@ -31,7 +31,6 @@
#pragma once
#include <iostream>
#include <cstdint>
#include "cutlass/float8.h"

View File

@ -0,0 +1,122 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cutlass/util/device_memory.h"
// Given a type of MMA instruction, compute a memory reordering atom that places all values
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
template<class MmaType>
auto compute_memory_reordering_atom()
{
using namespace cute;
// 1. Choose an MMA atom to access TV layout and MN shape
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<MmaType, MmaType, float, Shape<_64,_16,_32>>());
using MmaTraits = MMA_Traits<MmaAtom>;
auto shape_MK = select<0,2>(typename MmaTraits::Shape_MNK{});
auto tv_layout_mma = typename MmaTraits::ALayout{};
// 2. Create a single warp's TV layout from that of the whole MMA
// Note: this assumes A is partitioned between warps along M mode
auto tile_TV_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tile_TV_warp));
// 3. Invert warp's TV layout to get MK layout (m,k -> thr,val)
auto shape_MK_warp = shape_div(shape_MK, size(typename MmaTraits::ThrID{}) / Int<32>{});
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(shape_MK_warp);
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
auto tv_to_offset = make_ordered_layout(shape(tv_layout_mma_warp), Step<_1,_0>{});
auto layout_atom = composition(tv_to_offset, mk_layout_mma_warp);
return layout_atom;
}
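// A minimal usage sketch (names are illustrative): build the reordered layout for a
// quantized (N, K, L) tensor and shuffle the data on device with the helpers below.
//
//   using AtomLayout = decltype(compute_memory_reordering_atom<cutlass::bfloat16_t>());
//   auto layout_src = make_layout(shape_NKL, stride_NKL);      // canonical layout
//   auto layout_dst = tile_to_shape(AtomLayout{}, shape_NKL);  // reordered layout
//   reorder_tensor(ptr_B, layout_src, layout_dst);             // in-place overload below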
template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
__global__ void reorder_tensor_kernel(
cute::Tensor<EngineSrc, LayoutSrc> src,
cute::Tensor<EngineDst, LayoutDst> dst)
{
auto i = blockIdx.x;
auto k = blockIdx.y;
for (int j = threadIdx.x; j < cute::size<1>(src); j += blockDim.x) {
dst(i,j,k) = src(i,j,k);
}
}
template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
void reorder_tensor(
cute::Tensor<EngineSrc, LayoutSrc> t_src,
cute::Tensor<EngineDst, LayoutDst> t_dst)
{
using T = typename EngineDst::value_type;
static_assert(cute::is_same_v<cute::remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
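// Sub-byte types (e.g. int4) cannot be addressed individually, so recast both tensors
// to a byte-or-wider integer type; recast<> folds packed sub-byte elements together
// and adjusts the layouts accordingly.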
using V = cute::uint_bit_t<cute::max(8, cute::sizeof_bits_v<T>)>;
cute::Tensor v_src = cute::recast<V>(t_src);
cute::Tensor v_dst = cute::recast<V>(t_dst);
int threads = 256;
dim3 blocks{unsigned(cute::size<0>(v_src)), unsigned(cute::size<2>(v_src)), 1u};
reorder_tensor_kernel<<<blocks, threads>>>(v_src, v_dst);
CUDA_CHECK(cudaDeviceSynchronize());
}
// Pointer-based (out-of-place) overload
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T const* src,
LayoutSrc const& layout_src,
T * dst,
LayoutDst const& layout_dst)
{
reorder_tensor(make_tensor(src, layout_src),
make_tensor(dst, layout_dst));
}
// In-place version
template<class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
T * data,
LayoutSrc const& layout_src,
LayoutDst const& layout_dst)
{
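// The source and destination layouts alias the same storage, so reorder into a
// temporary device buffer first, then copy the result back over the original data.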
cutlass::DeviceAllocation<T> temp(cute::size(layout_src));
reorder_tensor(data, layout_src, temp.get(), layout_dst);
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(cute::size(layout_src)));
}

View File

@ -340,7 +340,7 @@ auto
all_of(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
} else {
return f(t);
}

View File

@ -198,13 +198,22 @@ is_major(Stride = {}) {
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(cute::remove_pointer_t<Stride>{})))>::value;
}
template<int ModeIndex, class Shape, class Stride>
constexpr bool
is_major(cute::Layout<Shape,Stride> = {}) {
return is_major<ModeIndex>(Stride{});
}
// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices
template<class StrideA>
constexpr
auto
stride_to_layout_tag_A() {
using InternalStrideA = cute::remove_pointer_t<StrideA>;
if constexpr (is_major<0, StrideA>()) { // M major
if constexpr (cute::is_layout<InternalStrideA>::value) {
return stride_to_layout_tag_A<decltype(cute::stride(InternalStrideA{}))>();
}
else if constexpr (is_major<0, StrideA>()) { // M major
return layout::ColumnMajor{};
}
// Specialize for sparse layout
@ -224,7 +233,11 @@ template<class StrideB>
constexpr
auto
stride_to_layout_tag_B() {
if constexpr (is_major<0, StrideB>()) { // N major
using InternalStrideB = cute::remove_pointer_t<StrideB>;
if constexpr (cute::is_layout<InternalStrideB>::value) {
return stride_to_layout_tag_B<decltype(cute::stride(InternalStrideB{}))>();
}
else if constexpr (is_major<0, StrideB>()) { // N major
return layout::RowMajor{};
}
else { // K major
@ -238,7 +251,11 @@ template<class StrideC>
constexpr
auto
stride_to_layout_tag_C() {
if constexpr (is_major<0, StrideC>()) { // M major
using InternalStrideC = cute::remove_pointer_t<StrideC>;
if constexpr (cute::is_layout<InternalStrideC>::value) {
return stride_to_layout_tag_C<decltype(cute::stride(InternalStrideC{}))>();
}
else if constexpr (is_major<0, StrideC>()) { // M major
return layout::ColumnMajor{};
}
else { // N major
@ -349,28 +366,25 @@ get_output_alignment_bits() {
return 128;
}
// Return the shape that is associated with stride-1 mode, or 1 if not found
template<typename Shape, typename Stride>
CUTLASS_HOST_DEVICE constexpr
auto
get_contiguous_shape(Shape const & shape, Stride const & stride) {
using namespace cute;
auto idx = find_if(append(flatten(stride), _1{}), [](auto s){ return is_constant<1,decltype(s)>{}; });
return get<decltype(idx)::value>(append(flatten(shape), _1{}));
}
// Check if tensor shape satisfies a given major alignment
// Check if tensor layout satisfies a given major alignment
template<int Alignment, class Shape, class Stride>
CUTLASS_HOST_DEVICE constexpr
bool
check_alignment(Shape const & shape, Stride const & stride) {
return is_major<0>(stride)
? get_contiguous_shape(cute::get<0>(shape), cute::get<0>(stride)) % Alignment == 0
: get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0;
check_alignment(cute::Layout<Shape,Stride> const& layout) {
// Condition: shape must divide by Alignment without rounding
bool shape_check = cute::size(layout.shape()) == Alignment * cute::size(cute::upcast<Alignment>(layout));
// Condition: every dynamic stride must be a multiple of Alignment
bool stride_check = cute::all_of(cute::flatten(layout.stride()), [](auto s){ return cute::is_static<decltype(s)>::value || (s % Alignment == 0); });
return shape_check && stride_check;
}
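// For instance (hypothetical values): with Alignment = 8 and a K-major layout of shape
// (128,64,1) and strides (64,_1,8192), size(layout) == 8192 == 8 * size(upcast<8>(layout)),
// and the dynamic strides 64 and 8192 are both multiples of 8, so the check passes.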
// Check if tensor shape satisfies a given major alignment
// Check if tensor layout satisfies a given major alignment
template<int Alignment, class Shape, class Stride>
CUTLASS_HOST_DEVICE constexpr
bool
check_alignment(Shape const& shape, Stride const& stride) {
return check_alignment<Alignment>(cute::make_layout(shape, stride));
}
template<int B, int M, int S>
CUTLASS_HOST_DEVICE constexpr

View File

@ -327,13 +327,23 @@ public:
if constexpr (is_destination_supported) {
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(shape, StrideD{});
if constexpr (cute::is_same_v<CopyOpS2G, SM90_TMA_STORE_IM2COL>) { // ignore L stride for implicit gemm
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(take<0,2>(shape), take<0,2>(StrideD{}));
}
else {
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(shape, StrideD{});
}
}
if constexpr (not cute::is_void_v<ElementC>) {
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(shape, StrideC{});
if constexpr (cute::is_same_v<CopyOpG2S, SM90_TMA_LOAD_IM2COL>) { // ignore L stride for implicit gemm
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(take<0,2>(shape), take<0,2>(StrideC{}));
}
else {
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(shape, StrideC{});
}
}
if (!implementable) {

View File

@ -409,8 +409,18 @@ public:
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
using GmemLayoutATag = GmemLayoutATag_;
using GmemLayoutBTag = GmemLayoutBTag_;
template<class T>
static auto get_stride(T const& t) {
if constexpr (not cute::is_layout<T>::value) {
return t;
}
else {
return cute::stride(t);
}
}
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
using ElementPairA = cute::conditional_t<IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementPairA_>;
using ElementPairB = cute::conditional_t<!IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementPairB_>;
@ -464,8 +474,8 @@ public:
using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
using StrideA = TagToStrideA_t<GmemLayoutATag>;
using StrideB = TagToStrideB_t<GmemLayoutBTag>;
using StrideA = cute::conditional_t<cute::is_layout<GmemLayoutATag_>::value, GmemLayoutATag_, TagToStrideA_t<GmemLayoutATag>>;
using StrideB = cute::conditional_t<cute::is_layout<GmemLayoutBTag_>::value, GmemLayoutBTag_, TagToStrideB_t<GmemLayoutBTag>>;
using CollectiveOp = CollectiveMma<
DispatchPolicy,

View File

@ -182,6 +182,7 @@ public:
using InternalSmemLayoutAtomB = cute::conditional_t<!SwapAB, SmemLayoutAtomB, SmemLayoutAtomA>;
using InternalSmemCopyAtomA = cute::conditional_t<!SwapAB, SmemCopyAtomA, SmemCopyAtomB>;
using InternalSmemCopyAtomB = cute::conditional_t<!SwapAB, SmemCopyAtomB, SmemCopyAtomA>;
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
@ -228,14 +229,25 @@ public:
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
InternalSmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
InternalSmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
template<class LayoutAtom, class TileShape, class Stride>
static constexpr
CUTLASS_HOST_DEVICE
auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) {
if constexpr (not cute::is_layout<Stride>::value) {
return tile_to_shape(
layout_atom,
append(tile_shape, Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{});
}
else {
auto gmem_tile = composition(stride, tile_shape);
return make_layout_like(append(gmem_tile, make_layout(Int<DispatchPolicy::Stages>{}, 0)));
}
}
using SmemLayoutA = decltype(get_smem_layout(InternalSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalStrideA{}));
using SmemLayoutB = decltype(get_smem_layout(InternalSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalStrideB{}));
// It is assumed that the scales and zero-points share the same smem layout
using SmemLayoutScale = decltype(tile_to_shape(
@ -381,6 +393,18 @@ public:
uint32_t mma_promotion_interval = 4;
};
template<class Shape, class Stride>
static constexpr
CUTLASS_HOST_DEVICE
auto get_gmem_layout(Shape const& shape, Stride const& stride) {
if constexpr (not cute::is_layout<Stride>::value) {
return make_layout(shape, stride);
}
else {
return stride;
}
}
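// With an ordinary stride (e.g. (int64_t,_1,int64_t)) this builds the usual (M,K,L) layout;
// when a pre-composed layout (such as a reordered B layout) is passed in place of a stride,
// it already carries its own shape and is returned as-is.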
// Device side kernel params
struct Params {
private:
@ -394,10 +418,14 @@ public:
TransformB_>;
public:
// Assumption: StrideA is congruent with Problem_MK
using LayoutA = decltype(get_gmem_layout(repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}));
using LayoutB = decltype(get_gmem_layout(repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}));
using TMA_A = decltype(make_tma_copy_A_sm90<TmaElementA>(
GmemTiledCopyA{},
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), LayoutA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{})); // mcast along N mode for this M load, if any
@ -419,7 +447,7 @@ public:
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy_B_sm90(
GmemTiledCopyB{},
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), LayoutB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{})); // mcast along M mode for this N load, if any
@ -431,6 +459,8 @@ public:
int group_size;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
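// Number of K tiles that consume the same scale group before new scales must be loaded.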
int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
InternalStrideA dA;
InternalStrideB dB;
};
//
@ -469,8 +499,8 @@ public:
dB = args.dA;
}
Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA));
Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB));
Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), get_gmem_layout(make_shape(M,K,L), dA));
Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), get_gmem_layout(make_shape(N,K,L), dB));
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90<TmaElementA>(
GmemTiledCopyA{},
tensor_a,
@ -490,7 +520,7 @@ public:
uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1 };
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB };
}
else if constexpr (ModeHasScales) {
auto scale_k = (K + args.group_size - 1) / args.group_size;
@ -505,7 +535,7 @@ public:
_1{}); // mcast along N mode for this M load, if any
if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) {
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
}
else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS));
@ -515,7 +545,7 @@ public:
SmemLayoutScale{}(_,_,cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in to_underlying_arguments.");
}
@ -533,33 +563,37 @@ public:
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
bool check_aligned_A = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(get_gmem_layout(cute::make_shape(M,K,L), args.dA));
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
bool check_aligned_B = cutlass::detail::check_alignment<min_tma_aligned_elements_B>(get_gmem_layout(cute::make_shape(N,K,L), args.dB));
bool check_aligned_S = true;
bool check_aligned_Z = true;
bool check_mode_args = true;
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
implementable = implementable && (args.ptr_S == nullptr);
implementable = implementable && (args.ptr_Z == nullptr);
check_mode_args = check_mode_args && (args.ptr_S == nullptr);
check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
const int scale_k = (K + args.group_size - 1) / args.group_size;
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
implementable = implementable && args.group_size != 0;
implementable = implementable && (args.ptr_S != nullptr);
check_aligned_S = cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), args.dS);
check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
check_mode_args = check_mode_args && args.group_size != 0;
check_mode_args = check_mode_args && (args.ptr_S != nullptr);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
implementable = implementable && (args.ptr_Z == nullptr);
check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.ptr_Z != nullptr);
check_aligned_Z = cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), args.dS);
check_mode_args = check_mode_args && (args.ptr_Z != nullptr);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
@ -569,10 +603,23 @@ public:
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
if (!check_mode_args) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n");
}
return implementable;
if (!check_aligned_A) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A doesn't meet the minimum alignment requirements for TMA.\n");
}
if (!check_aligned_B) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B doesn't meet the minimum alignment requirements for TMA.\n");
}
if (!check_aligned_S) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) doesn't meet the minimum alignment requirements for TMA.\n");
}
if (!check_aligned_Z) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) doesn't meet the minimum alignment requirements for TMA.\n");
}
return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
@ -618,8 +665,8 @@ public:
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(get_gmem_layout(make_shape(M,K,L), mainloop_params.dA))); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(get_gmem_layout(make_shape(N,K,L), mainloop_params.dB))); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
@ -680,8 +727,6 @@ public:
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
}
int lane_predicate = cute::elect_one_sync();
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
@ -748,8 +793,10 @@ public:
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
if (cute::elect_one_sync()) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
if (cute::elect_one_sync()) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
if (cute::elect_one_sync()) {
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
}
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Nothing extra to do.
@ -920,6 +967,12 @@ public:
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
if (k_block < K_BLOCK_MAX - 2) { // prefetch next block
copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view,
partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage);
@ -927,11 +980,6 @@ public:
if (k_block < K_BLOCK_MAX - 1) {
transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1);
}
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
}
--k_tile_count;

View File

@ -70,7 +70,7 @@ using cutlass::detail::StrideToLayoutTagC_t;
template<int ModeIndex, class Stride>
constexpr bool
is_major(Stride = {}) {
return ::cutlass::detail::is_major<ModeIndex, Stride>();
return ::cutlass::detail::is_major<ModeIndex>(Stride{});
}
template<class Stride>