Improve sm90 mixed dtype kernel (#1883)
This commit is contained in:
657
examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu
Normal file
657
examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu
Normal file
@ -0,0 +1,657 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
|
||||
This example shows how to perform INT4 x BF16 GEMM and scale up the INT4 weight during dequantization.
|
||||
|
||||
The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap
|
||||
A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue,
|
||||
as illustrated in this example.
|
||||
|
||||
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
|
||||
|
||||
As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory.
|
||||
This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in FP8 data type, each thread reads
|
||||
4 groups of 2 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-a for thread-value layout).
|
||||
If the narrow type is INT4 and tensor is major in K dim, only 8 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
|
||||
If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized
|
||||
tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro
|
||||
OPTIMIZE_WEIGHT_LAYOUT is set to 1.
|
||||
|
||||
It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).
|
||||
|
||||
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
|
||||
|
||||
If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k].
|
||||
|
||||
The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size
|
||||
equal to the gemm problem K.
|
||||
|
||||
Limitations:
|
||||
1) Only supports INT4 x { FP16, BF16 }. The scales must be the same as mma Type. Scale with zero-point mode is not supported.
|
||||
2) The INT4 weights have additional encoding requirements.
|
||||
3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major.
|
||||
4) The scales must have the same layout and groupsize.
|
||||
5) The groupsize must be greater or equal to the tile shape k.
|
||||
6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the
|
||||
operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations.
|
||||
We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands.
|
||||
|
||||
Optimizing suggestions:
|
||||
1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space).
|
||||
|
||||
Examples:
|
||||
|
||||
Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0)
|
||||
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0
|
||||
|
||||
Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire
|
||||
matrix (group size is the same as the gemm k dimension).
|
||||
$ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "helper.h"
|
||||
#include "unfused_weight_dequantize.hpp"
|
||||
#include "packed_scale.hpp"
|
||||
#include "reorder_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#define OPTIMIZE_WEIGHT_LAYOUT 1
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
using MmaType = cutlass::bfloat16_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
|
||||
// It specifies the reordering within a single warp's fragment
|
||||
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
#endif
|
||||
|
||||
using ElementScale = MmaType;
|
||||
using ElementZero = ElementScale; // only for verify
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
// Transpose layout of D here since we use explicit swap + transpose
|
||||
// the void type for C tells the builder to allocate 0 smem for the C matrix.
|
||||
// We can enable this if beta == 0 by changing ElementC to void below.
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
|
||||
EpilogueSchedule // This is the only epi supporting the required swap + transpose.
|
||||
>::CollectiveOp;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Reordered, AlignmentB,
|
||||
#else
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose, AlignmentB,
|
||||
#endif
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
|
||||
using StrideC = typename GemmKernelScaleOnly::StrideC;
|
||||
using StrideD = typename GemmKernelScaleOnly::StrideD;
|
||||
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideC_ref stride_C_ref;
|
||||
StrideD stride_D;
|
||||
StrideD_ref stride_D_ref;
|
||||
uint64_t seed;
|
||||
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
LayoutB_Reordered layout_B_reordered;
|
||||
#endif
|
||||
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
StrideS stride_S;
|
||||
StrideS_ref stride_S_ref;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementA> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int iterations = 10;
|
||||
int m = 5120, n = 4096, k = 4096;
|
||||
int g = 128;
|
||||
int l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("g", g);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "55_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> The number of independent gemm problems with mnk shape\n"
|
||||
<< " --g=<int> The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k * l;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
bool initialize_quant_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
|
||||
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_scale(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
Options const& options) {
|
||||
|
||||
float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
|
||||
float const max_dequant_val = 4.f;
|
||||
float const min_dequant_val = 0.5f;
|
||||
|
||||
float scope_max(max_dequant_val / elt_max_f);
|
||||
float scope_min(min_dequant_val / elt_max_f);
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_zero(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
Options const& options) {
|
||||
std::vector<Element> stage(block.size(), Element(0.0f));
|
||||
block.copy_from_host(stage.data());
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto layout_B = make_layout(shape_B, stride_B);
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
|
||||
block_A.reset(a_coord.product());
|
||||
block_B.reset(b_coord.product());
|
||||
block_B_dq.reset(b_coord.product());
|
||||
block_C.reset(c_coord.product());
|
||||
block_D.reset(c_coord.product());
|
||||
block_ref_D.reset(c_coord.product());
|
||||
|
||||
block_scale.reset(scale_k * options.l * options.n);
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
|
||||
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
|
||||
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered);
|
||||
print("\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Args>
|
||||
Args args_from_options(Options const& options)
|
||||
{
|
||||
// Swap the A and B tensors, as well as problem shapes here.
|
||||
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
{block_B.get(), layout_B_reordered, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
|
||||
#else
|
||||
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
|
||||
#endif
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{block_A.get(), stride_A, block_B_dq.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// compare_reference
|
||||
ElementD const epsilon(1e-2f);
|
||||
ElementD const non_zero_floor(1e-4f);
|
||||
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
// {$nv-internal-release begin}
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
// {$nv-internal-release end}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
if (options.g == options.k) {
|
||||
std::cout << "Running in per-column scale mode." << std::endl;
|
||||
} else {
|
||||
std::cout << "Running in group scale mode." << std::endl;
|
||||
}
|
||||
run<GemmScaleOnly>(options);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -47,6 +47,14 @@
|
||||
|
||||
Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest.
|
||||
|
||||
As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory.
|
||||
This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in FP8 data type, each thread reads
|
||||
4 groups of 4 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n32-a for thread-value layout).
|
||||
If the narrow type is INT4 and tensor is major in K dim, only 16 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput.
|
||||
If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized
|
||||
tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro
|
||||
OPTIMIZE_WEIGHT_LAYOUT is set to 1.
|
||||
|
||||
It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size).
|
||||
|
||||
Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled.
|
||||
@ -104,9 +112,12 @@
|
||||
#include "helper.h"
|
||||
#include "unfused_weight_dequantize.hpp"
|
||||
#include "packed_scale.hpp"
|
||||
#include "reorder_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#define OPTIMIZE_WEIGHT_LAYOUT 1
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -130,6 +141,17 @@ constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // M
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
|
||||
// It specifies the reordering within a single warp's fragment
|
||||
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
#endif
|
||||
|
||||
using ElementScale = MmaType;
|
||||
using ElementZero = ElementScale; // only for verify
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
@ -172,7 +194,11 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8> >, LayoutB_Transpose, AlignmentB,
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Reordered, AlignmentB,
|
||||
#else
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Transpose, AlignmentB,
|
||||
#endif
|
||||
ElementA, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
@ -190,8 +216,6 @@ using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
using StrideC = typename GemmKernelScaleOnly::StrideC;
|
||||
using StrideD = typename GemmKernelScaleOnly::StrideD;
|
||||
|
||||
@ -211,6 +235,10 @@ StrideD stride_D;
|
||||
StrideD_ref stride_D_ref;
|
||||
uint64_t seed;
|
||||
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
LayoutB_Reordered layout_B_reordered;
|
||||
#endif
|
||||
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
StrideS stride_S;
|
||||
@ -399,7 +427,7 @@ bool unify_quant_encoding(
|
||||
d = out;
|
||||
}
|
||||
|
||||
cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2);
|
||||
cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -461,10 +489,10 @@ bool initialize_zero(
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_b = cute::make_shape(options.n, options.k, options.l);
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
@ -472,6 +500,8 @@ void initialize(Options const& options) {
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l));
|
||||
stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l));
|
||||
|
||||
auto layout_B = make_layout(shape_B, stride_B);
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
@ -496,14 +526,22 @@ void initialize(Options const& options) {
|
||||
initialize_packed_scale(block_scale, block_scale_packed);
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
auto layout_B = make_layout(shape_b, stride_B);
|
||||
|
||||
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
|
||||
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
|
||||
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered);
|
||||
print("\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
@ -515,7 +553,11 @@ Args args_from_options(Options const& options)
|
||||
return Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.n, options.m, options.k, options.l},
|
||||
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
|
||||
#if OPTIMIZE_WEIGHT_LAYOUT
|
||||
{block_B_modified.get(), layout_B_reordered, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
|
||||
#else
|
||||
{block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g},
|
||||
#endif
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
}
|
||||
@ -581,6 +623,7 @@ bool verify(Options const& options) {
|
||||
ElementD const epsilon(1e-2f);
|
||||
ElementD const non_zero_floor(1e-4f);
|
||||
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
|
||||
@ -68,3 +68,14 @@ cutlass_example_add_executable(
|
||||
TEST_SCALE_RESIDUE
|
||||
# TEST_ALPHA_BETA
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
55_hopper_int4_bf16_gemm
|
||||
55_hopper_int4_bf16_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
TEST_SCALE_RESIDUE
|
||||
# TEST_ALPHA_BETA
|
||||
)
|
||||
|
||||
@ -31,7 +31,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdint>
|
||||
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
122
examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp
Normal file
122
examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp
Normal file
@ -0,0 +1,122 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
// Given a type of MMA instruction, compute a memory reordering atom that places all values
|
||||
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
|
||||
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
|
||||
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
|
||||
template<class MmaType>
|
||||
auto compute_memory_reordering_atom()
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
// 1. Choose an MMA atom to access TV layout and MN shape
|
||||
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
|
||||
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<MmaType, MmaType, float, Shape<_64,_16,_32>>());
|
||||
using MmaTraits = MMA_Traits<MmaAtom>;
|
||||
auto shape_MK = select<0,2>(typename MmaTraits::Shape_MNK{});
|
||||
auto tv_layout_mma = typename MmaTraits::ALayout{};
|
||||
|
||||
// 2. Create a single warp's TV layout from that of the whole MMA
|
||||
// Note: this assumes A is partitioned between warps along M mode
|
||||
auto tile_TV_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
|
||||
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tile_TV_warp));
|
||||
|
||||
// 3. Invert warp's TV layout to get MK layout (m,k -> thr,val)
|
||||
auto shape_MK_warp = shape_div(shape_MK, size(typename MmaTraits::ThrID{}) / Int<32>{});
|
||||
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(shape_MK_warp);
|
||||
|
||||
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
|
||||
auto tv_to_offset = make_ordered_layout(shape(tv_layout_mma_warp), Step<_1,_0>{});
|
||||
auto layout_atom = composition(tv_to_offset, mk_layout_mma_warp);
|
||||
|
||||
return layout_atom;
|
||||
}
|
||||
|
||||
template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
|
||||
__global__ void reorder_tensor_kernel(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> src,
|
||||
cute::Tensor<EngineDst, LayoutDst> dst)
|
||||
{
|
||||
auto i = blockIdx.x;
|
||||
auto k = blockIdx.y;
|
||||
for (int j = threadIdx.x; j < cute::size<1>(src); j += blockDim.x) {
|
||||
dst(i,j,k) = src(i,j,k);
|
||||
}
|
||||
}
|
||||
|
||||
template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> t_src,
|
||||
cute::Tensor<EngineDst, LayoutDst> t_dst)
|
||||
{
|
||||
using T = typename EngineDst::value_type;
|
||||
static_assert(cute::is_same_v<cute::remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
|
||||
using V = cute::uint_bit_t<cute::max(8, cute::sizeof_bits_v<T>)>;
|
||||
|
||||
cute::Tensor v_src = cute::recast<V>(t_src);
|
||||
cute::Tensor v_dst = cute::recast<V>(t_dst);
|
||||
|
||||
int threads = 256;
|
||||
dim3 blocks{unsigned(cute::size<0>(v_src)), unsigned(cute::size<2>(v_src)), 1u};
|
||||
|
||||
reorder_tensor_kernel<<<blocks, threads>>>(v_src, v_dst);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// In-place version
|
||||
template<class T, class LayoutSrc, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
T const* src,
|
||||
LayoutSrc const& layout_src,
|
||||
T * dst,
|
||||
LayoutDst const& layout_dst)
|
||||
{
|
||||
reorder_tensor(make_tensor(src, layout_src),
|
||||
make_tensor(dst, layout_dst));
|
||||
}
|
||||
|
||||
// In-place version
|
||||
template<class T, class LayoutSrc, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
T * data,
|
||||
LayoutSrc const& layout_src,
|
||||
LayoutDst const& layout_dst)
|
||||
{
|
||||
cutlass::DeviceAllocation<T> temp(cute::size(layout_src));
|
||||
reorder_tensor(data, layout_src, temp.get(), layout_dst);
|
||||
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(cute::size(layout_src)));
|
||||
}
|
||||
@ -340,7 +340,7 @@ auto
|
||||
all_of(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
|
||||
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
@ -198,13 +198,22 @@ is_major(Stride = {}) {
|
||||
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(cute::remove_pointer_t<Stride>{})))>::value;
|
||||
}
|
||||
|
||||
template<int ModeIndex, class Shape, class Stride>
|
||||
constexpr bool
|
||||
is_major(cute::Layout<Shape,Stride> = {}) {
|
||||
return is_major<ModeIndex>(Stride{});
|
||||
}
|
||||
|
||||
// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices
|
||||
template<class StrideA>
|
||||
constexpr
|
||||
auto
|
||||
stride_to_layout_tag_A() {
|
||||
using InternalStrideA = cute::remove_pointer_t<StrideA>;
|
||||
if constexpr (is_major<0, StrideA>()) { // M major
|
||||
if constexpr (cute::is_layout<InternalStrideA>::value) {
|
||||
return stride_to_layout_tag_A<decltype(cute::stride(InternalStrideA{}))>();
|
||||
}
|
||||
else if constexpr (is_major<0, StrideA>()) { // M major
|
||||
return layout::ColumnMajor{};
|
||||
}
|
||||
// Specialize for sparse layout
|
||||
@ -224,7 +233,11 @@ template<class StrideB>
|
||||
constexpr
|
||||
auto
|
||||
stride_to_layout_tag_B() {
|
||||
if constexpr (is_major<0, StrideB>()) { // N major
|
||||
using InternalStrideB = cute::remove_pointer_t<StrideB>;
|
||||
if constexpr (cute::is_layout<InternalStrideB>::value) {
|
||||
return stride_to_layout_tag_B<decltype(cute::stride(InternalStrideB{}))>();
|
||||
}
|
||||
else if constexpr (is_major<0, StrideB>()) { // N major
|
||||
return layout::RowMajor{};
|
||||
}
|
||||
else { // K major
|
||||
@ -238,7 +251,11 @@ template<class StrideC>
|
||||
constexpr
|
||||
auto
|
||||
stride_to_layout_tag_C() {
|
||||
if constexpr (is_major<0, StrideC>()) { // M major
|
||||
using InternalStrideC = cute::remove_pointer_t<StrideC>;
|
||||
if constexpr (cute::is_layout<InternalStrideC>::value) {
|
||||
return stride_to_layout_tag_C<decltype(cute::stride(InternalStrideC{}))>();
|
||||
}
|
||||
else if constexpr (is_major<0, StrideC>()) { // M major
|
||||
return layout::ColumnMajor{};
|
||||
}
|
||||
else { // N major
|
||||
@ -349,28 +366,25 @@ get_output_alignment_bits() {
|
||||
return 128;
|
||||
}
|
||||
|
||||
|
||||
// Return the shape that is associated with stride-1 mode, or 1 if not found
|
||||
template<typename Shape, typename Stride>
|
||||
CUTLASS_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_contiguous_shape(Shape const & shape, Stride const & stride) {
|
||||
using namespace cute;
|
||||
auto idx = find_if(append(flatten(stride), _1{}), [](auto s){ return is_constant<1,decltype(s)>{}; });
|
||||
return get<decltype(idx)::value>(append(flatten(shape), _1{}));
|
||||
}
|
||||
|
||||
// Check if tensor shape satisfies a given major alignment
|
||||
// Check if tensor layout satisfies a given major alignment
|
||||
template<int Alignment, class Shape, class Stride>
|
||||
CUTLASS_HOST_DEVICE constexpr
|
||||
bool
|
||||
check_alignment(Shape const & shape, Stride const & stride) {
|
||||
return is_major<0>(stride)
|
||||
? get_contiguous_shape(cute::get<0>(shape), cute::get<0>(stride)) % Alignment == 0
|
||||
: get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0;
|
||||
check_alignment(cute::Layout<Shape,Stride> const& layout) {
|
||||
// Condition: shape must divide by Alignment without rounding
|
||||
bool shape_check = cute::size(layout.shape()) == Alignment * cute::size(cute::upcast<Alignment>(layout));
|
||||
// Condition: every dynamic stride must be a multiple of Alignment
|
||||
bool stride_check = cute::all_of(cute::flatten(layout.stride()), [](auto s){ return cute::is_static<decltype(s)>::value || (s % Alignment == 0); });
|
||||
return shape_check && stride_check;
|
||||
}
|
||||
|
||||
// Check if tensor shape satisfies a given major alignment
|
||||
// Check if tensor layout satisfies a given major alignment
|
||||
template<int Alignment, class Shape, class Stride>
|
||||
CUTLASS_HOST_DEVICE constexpr
|
||||
bool
|
||||
check_alignment(Shape const& shape, Stride const& stride) {
|
||||
return check_alignment<Alignment>(cute::make_layout(shape, stride));
|
||||
}
|
||||
|
||||
template<int B, int M, int S>
|
||||
CUTLASS_HOST_DEVICE constexpr
|
||||
|
||||
@ -327,13 +327,23 @@ public:
|
||||
if constexpr (is_destination_supported) {
|
||||
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
|
||||
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
|
||||
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(shape, StrideD{});
|
||||
if constexpr (cute::is_same_v<CopyOpS2G, SM90_TMA_STORE_IM2COL>) { // ignore L stride for implicit gemm
|
||||
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(take<0,2>(shape), take<0,2>(StrideD{}));
|
||||
}
|
||||
else {
|
||||
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(shape, StrideD{});
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
|
||||
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(shape, StrideC{});
|
||||
if constexpr (cute::is_same_v<CopyOpG2S, SM90_TMA_LOAD_IM2COL>) { // ignore L stride for implicit gemm
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(take<0,2>(shape), take<0,2>(StrideC{}));
|
||||
}
|
||||
else {
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(shape, StrideC{});
|
||||
}
|
||||
}
|
||||
|
||||
if (!implementable) {
|
||||
|
||||
@ -409,8 +409,18 @@ public:
|
||||
|
||||
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
|
||||
|
||||
using GmemLayoutATag = GmemLayoutATag_;
|
||||
using GmemLayoutBTag = GmemLayoutBTag_;
|
||||
template<class T>
|
||||
static auto get_stride(T const& t) {
|
||||
if constexpr (not cute::is_layout<T>::value) {
|
||||
return t;
|
||||
}
|
||||
else {
|
||||
return cute::stride(t);
|
||||
}
|
||||
}
|
||||
|
||||
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
|
||||
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
|
||||
|
||||
using ElementPairA = cute::conditional_t<IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementPairA_>;
|
||||
using ElementPairB = cute::conditional_t<!IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementPairB_>;
|
||||
@ -464,8 +474,8 @@ public:
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
|
||||
|
||||
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
|
||||
using StrideA = TagToStrideA_t<GmemLayoutATag>;
|
||||
using StrideB = TagToStrideB_t<GmemLayoutBTag>;
|
||||
using StrideA = cute::conditional_t<cute::is_layout<GmemLayoutATag_>::value, GmemLayoutATag_, TagToStrideA_t<GmemLayoutATag>>;
|
||||
using StrideB = cute::conditional_t<cute::is_layout<GmemLayoutBTag_>::value, GmemLayoutBTag_, TagToStrideB_t<GmemLayoutBTag>>;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
|
||||
@ -182,6 +182,7 @@ public:
|
||||
using InternalSmemLayoutAtomB = cute::conditional_t<!SwapAB, SmemLayoutAtomB, SmemLayoutAtomA>;
|
||||
using InternalSmemCopyAtomA = cute::conditional_t<!SwapAB, SmemCopyAtomA, SmemCopyAtomB>;
|
||||
using InternalSmemCopyAtomB = cute::conditional_t<!SwapAB, SmemCopyAtomB, SmemCopyAtomA>;
|
||||
|
||||
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
||||
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
|
||||
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
||||
@ -228,14 +229,25 @@ public:
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
InternalSmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
InternalSmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
template<class LayoutAtom, class TileShape, class Stride>
|
||||
static constexpr
|
||||
CUTLASS_HOST_DEVICE
|
||||
auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) {
|
||||
if constexpr (not cute::is_layout<Stride>::value) {
|
||||
return tile_to_shape(
|
||||
layout_atom,
|
||||
append(tile_shape, Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{});
|
||||
}
|
||||
else {
|
||||
auto gmem_tile = composition(stride, tile_shape);
|
||||
return make_layout_like(append(gmem_tile, make_layout(Int<DispatchPolicy::Stages>{}, 0)));
|
||||
}
|
||||
}
|
||||
|
||||
using SmemLayoutA = decltype(get_smem_layout(InternalSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalStrideA{}));
|
||||
using SmemLayoutB = decltype(get_smem_layout(InternalSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalStrideB{}));
|
||||
|
||||
   // It is assumed that the scales and zero-points share the same smem layout
   using SmemLayoutScale = decltype(tile_to_shape(
@@ -381,6 +393,18 @@ public:
     uint32_t mma_promotion_interval = 4;
   };

+  template<class Shape, class Stride>
+  static constexpr
+  CUTLASS_HOST_DEVICE
+  auto get_gmem_layout(Shape const& shape, Stride const& stride) {
+    if constexpr (not cute::is_layout<Stride>::value) {
+      return make_layout(shape, stride);
+    }
+    else {
+      return stride;
+    }
+  }
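get_gmem_layout is the gmem-side counterpart of get_smem_layout: for an ordinary stride it reproduces the old make_layout(shape, stride) behavior, and when the "stride" is actually a full cute::Layout it returns that layout untouched (the shape argument is ignored). A small usage sketch with assumed extents, presuming the collective's class scope:

// Stride path: identical to make_layout(make_shape(M,K,L), dA).
auto layout_a = get_gmem_layout(cute::make_shape(1024, 512, 1),
                                cute::make_stride(int64_t{512}, cute::Int<1>{}, int64_t{0}));
// Layout path: a user-provided layout is returned as-is.
auto custom   = cute::make_layout(cute::make_shape(1024, 512, 1));
auto layout_b = get_gmem_layout(cute::make_shape(1024, 512, 1), custom);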
  // Device side kernel params
  struct Params {
  private:
@@ -394,10 +418,14 @@ public:
        TransformB_>;

  public:
    // Assumption: StrideA is congruent with Problem_MK
+    using LayoutA = decltype(get_gmem_layout(repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}));
+    using LayoutB = decltype(get_gmem_layout(repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}));
+
    using TMA_A = decltype(make_tma_copy_A_sm90<TmaElementA>(
        GmemTiledCopyA{},
-        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
+        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), LayoutA{}),
        SmemLayoutA{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{}));  // mcast along N mode for this M load, if any
@@ -419,7 +447,7 @@ public:
    // Assumption: StrideB is congruent with Problem_NK
    using TMA_B = decltype(make_tma_copy_B_sm90(
        GmemTiledCopyB{},
-        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
+        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), LayoutB{}),
        SmemLayoutB{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{}));  // mcast along M mode for this N load, if any
@@ -431,6 +459,8 @@ public:
    int group_size;
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
    int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
+    InternalStrideA dA;
+    InternalStrideB dB;
  };

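reload_factor is a ceiling division: it says how many consecutive K tiles share one scale/zero load. For example, with group_size = 128 and a K tile of 64, reload_factor = (128 + 64 - 1) / 64 = 2, so fresh scales are fetched every second K tile; with group_size = K (per-channel quantization) they are fetched once per problem.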
  //
@@ -469,8 +499,8 @@ public:
      dB = args.dA;
    }

-    Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA));
-    Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB));
+    Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), get_gmem_layout(make_shape(M,K,L), dA));
+    Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), get_gmem_layout(make_shape(N,K,L), dB));
    typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90<TmaElementA>(
        GmemTiledCopyA{},
        tensor_a,
@@ -490,7 +520,7 @@ public:

    uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
-      return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1 };
+      return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB };
    }
    else if constexpr (ModeHasScales) {
      auto scale_k = (K + args.group_size - 1) / args.group_size;
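scale_k is the number of quantization groups along K, again by ceiling division: for K = 4096 and group_size = 128, scale_k = (4096 + 128 - 1) / 128 = 32, giving a scale tensor of extent (scale_mn, 32, L).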
@@ -505,7 +535,7 @@ public:
          _1{});  // mcast along N mode for this M load, if any

      if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) {
-        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
+        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
      }
      else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS));
@@ -515,7 +545,7 @@ public:
            SmemLayoutScale{}(_,_,cute::Int<0>{}),
            ScaleTileShape{},
            _1{});  // mcast along N mode for this M load, if any
-        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
+        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
      } else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in to_underlying_arguments.");
      }
@@ -533,33 +563,37 @@ public:
    constexpr int tma_alignment_bits = 128;
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

-    bool implementable = true;
-
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
-    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
+    bool check_aligned_A = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(get_gmem_layout(cute::make_shape(M,K,L), args.dA));

    constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
-    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
+    bool check_aligned_B = cutlass::detail::check_alignment<min_tma_aligned_elements_B>(get_gmem_layout(cute::make_shape(N,K,L), args.dB));

+    bool check_aligned_S = true;
+    bool check_aligned_Z = true;
+    bool check_mode_args = true;

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
-      implementable = implementable && (args.ptr_S == nullptr);
-      implementable = implementable && (args.ptr_Z == nullptr);
+      check_mode_args = check_mode_args && (args.ptr_S == nullptr);
+      check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
    }
    else if constexpr (ModeHasScales) {
      const int scale_mn = SwapAB ? N : M;
      const int scale_k = (K + args.group_size - 1) / args.group_size;
      constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
-      implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
-      implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
-      implementable = implementable && args.group_size != 0;
-      implementable = implementable && (args.ptr_S != nullptr);
+      check_aligned_S = cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), args.dS);
+      check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
+      check_mode_args = check_mode_args && args.group_size != 0;
+      check_mode_args = check_mode_args && (args.ptr_S != nullptr);

      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
-        implementable = implementable && (args.ptr_Z == nullptr);
+        check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
-        implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
-        implementable = implementable && (args.ptr_Z != nullptr);
+        check_aligned_Z = cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), args.dS);
+        check_mode_args = check_mode_args && (args.ptr_Z != nullptr);
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
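The minimum element counts come from TMA's 128-bit alignment requirement: min_tma_aligned_elements = 128 / sizeof_bits(Element). A quick worked example for this commit's INT4 x BF16 case: 128 / 4 = 32 int4 elements for the narrow operand and 128 / 16 = 8 bf16 elements for the wide one; check_alignment then verifies that each tensor's shape/stride combination honors that multiple.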
@@ -569,10 +603,23 @@ public:
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
    }

-    if (!implementable) {
-      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
+    if (!check_mode_args) {
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n");
    }
-    return implementable;
+    if (!check_aligned_A) {
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A does not meet the minimum alignment requirements for TMA.\n");
+    }
+    if (!check_aligned_B) {
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B does not meet the minimum alignment requirements for TMA.\n");
+    }
+    if (!check_aligned_S) {
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) does not meet the minimum alignment requirements for TMA.\n");
+    }
+    if (!check_aligned_Z) {
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) does not meet the minimum alignment requirements for TMA.\n");
+    }
+
+    return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z;
  }

  static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
@@ -618,8 +665,8 @@ public:

    // TMA requires special handling of strides to deal with coord codomain mapping
    // Represent the full tensors -- get these from TMA
-    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L));  // (m,k,l)
-    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L));  // (n,k,l)
+    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(get_gmem_layout(make_shape(M,K,L), mainloop_params.dA)));  // (m,k,l)
+    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(get_gmem_layout(make_shape(N,K,L), mainloop_params.dB)));  // (n,k,l)

    // Make tiled views, defer the slice
    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});  // (BLK_M,BLK_K,m,k,l)
@@ -680,8 +727,6 @@ public:
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
    }

-    int lane_predicate = cute::elect_one_sync();
-
    Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{});  // (BLK_M,BLK_K,PIPE)
    Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{});  // (BLK_N,BLK_K,PIPE)
    Tensor sA  = as_position_independent_swizzle_tensor(sA_);                               // (BLK_M,BLK_K,PIPE)
@@ -748,8 +793,10 @@ public:
        BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

        int write_stage = smem_pipe_write.index();
-        if (cute::elect_one_sync()) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
-        if (cute::elect_one_sync()) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
+        if (cute::elect_one_sync()) {
+          copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
+          copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
+        }

        if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
          // Nothing extra to do.
@@ -920,6 +967,12 @@ public:
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
+        warpgroup_arrive();
+        // (V,M) x (V,N) => (V,M,N)
+        cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
+        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+        warpgroup_commit_batch();
+
        if (k_block < K_BLOCK_MAX - 2) {  // prefetch next block
          copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view,
            partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage);
@@ -927,11 +980,6 @@ public:
        if (k_block < K_BLOCK_MAX - 1) {
          transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1);
        }
-        warpgroup_arrive();
-        // (V,M) x (V,N) => (V,M,N)
-        cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
-        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
-        warpgroup_commit_batch();
      }

      --k_tile_count;
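Taken together, the two hunks above move the WGMMA issue to the top of the k_block loop, so the math on block k overlaps the shared-memory prefetch of block k+2 and the register-file conversion of block k+1 instead of waiting behind them. An illustrative per-iteration schedule for an assumed K_BLOCK_MAX = 4:

// k_block 0: gemm(0) | prefetch smem block 2 | convert block 1
// k_block 1: gemm(1) | prefetch smem block 3 | convert block 2
// k_block 2: gemm(2) |                       | convert block 3
// k_block 3: gemm(3)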
@@ -70,7 +70,7 @@ using cutlass::detail::StrideToLayoutTagC_t;

template<int ModeIndex, class Stride>
constexpr bool
is_major(Stride = {}) {
-  return ::cutlass::detail::is_major<ModeIndex, Stride>();
+  return ::cutlass::detail::is_major<ModeIndex>(Stride{});
}
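The one-line fix calls the underlying helper with the stride as a function argument, letting it be deduced, rather than pinning it as an explicit template argument. A hedged usage sketch (the stride types are assumptions for illustration):

// Mode 0 (M) has a static unit stride => M-major; a K-major stride does not.
static_assert( is_major<0>(cute::Stride<cute::Int<1>, int64_t, int64_t>{}), "M-major");
static_assert(!is_major<0>(cute::Stride<int64_t, cute::Int<1>, int64_t>{}), "K-major");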

template<class Stride>