diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu new file mode 100644 index 00000000..9346734a --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -0,0 +1,657 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example shows how to perform INT4 x BF16 GEMM and scale up the INT4 weight during dequantization. + + The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap + A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue, + as illustrated in this example. + + Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + + As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. + This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in BF16 data type, each thread reads + 4 groups of 2 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-a for thread-value layout). 
+ If the narrow type is INT4 and tensor is major in K dim, only 8 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput. + If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized + tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro + OPTIMIZE_WEIGHT_LAYOUT is set to 1. + + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). + + Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. + + If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k]. + + The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size + equal to the gemm problem K. + + Limitations: + 1) Only supports INT4 x { FP16, BF16 }. The scales must be the same as mma Type. Scale with zero-point mode is not supported. + 2) The INT4 weights have additional encoding requirements. + 3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 4) The scales must have the same layout and groupsize. + 5) The groupsize must be greater or equal to the tile shape k. + 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. + We plan to address this limitation in the future. 
However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space). + + Examples: + + Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0) + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0 + + Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire + matrix (group size is the same as the gemm k dimension). + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 +*/ + +#include <iostream> + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "helper.h" +#include "unfused_weight_dequantize.hpp" +#include "packed_scale.hpp" +#include "reorder_utils.hpp" + +using namespace cute; + +#define OPTIMIZE_WEIGHT_LAYOUT 1 + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations 
+///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::bfloat16_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type; + +using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>; +using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>; + +#if OPTIMIZE_WEIGHT_LAYOUT +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory. 
+// It specifies the reordering within a single warp's fragment +using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); +#endif + +using ElementScale = MmaType; +using ElementZero = ElementScale; // only for verify +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use explicit swap + 
transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, +#if OPTIMIZE_WEIGHT_LAYOUT + cute::tuple, LayoutB_Reordered, AlignmentB, +#else + cute::tuple, LayoutB_Transpose, AlignmentB, +#endif + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideC = typename GemmKernelScaleOnly::StrideC; +using StrideD = typename GemmKernelScaleOnly::StrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideC_ref stride_C_ref; +StrideD stride_D; +StrideD_ref stride_D_ref; +uint64_t seed; + +#if OPTIMIZE_WEIGHT_LAYOUT +LayoutB_Reordered layout_B_reordered; +#endif + +using 
StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 10; + int m = 5120, n = 4096, k = 4096; + int g = 128; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("g", g); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_warp_specialized_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if 
(bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + Options const& options) { + + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + float const max_dequant_val = 4.f; + float const min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + Options const& options) { + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options const& options) { + + auto shape_B = cute::make_shape(options.n, options.k, options.l); + int const scale_k = (options.k + options.g - 1) / options.g; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + // Reverse stride here due to swap and transpose + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, 
options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + + auto layout_B = make_layout(shape_B, stride_B); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); + + block_scale.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); + + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); + + auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); + stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); + stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); + auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); + +#if OPTIMIZE_WEIGHT_LAYOUT + // Repeat the reorder layout atom to tile the whole tensor shape + layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); + reorder_tensor(block_B.get(), layout_B, layout_B_reordered); + + print("Quantized tensor layout: "); 
+ print(layout_B_reordered); + print("\n"); +#endif +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +Args args_from_options(Options const& options) +{ +// Swap the A and B tensors, as well as problem shapes here. + + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, +#if OPTIMIZE_WEIGHT_LAYOUT + {block_B.get(), layout_B_reordered, block_A.get(), stride_A, block_scale.get(), stride_S, options.g}, +#else + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g}, +#endif + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; +} + +bool verify(Options const& options) { + // + // Compute reference output + // + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, 
block_ref_D.get(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(&current_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // {$nv-internal-release begin} + else if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + // {$nv-internal-release end} + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." << std::endl; + } + run<GemmScaleOnly>(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu index 138f7a04..eee10e01 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -47,6 +47,14 @@ Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. 
+ This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in FP8 data type, each thread reads + 4 groups of 4 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n32-a for thread-value layout). + If the narrow type is INT4 and tensor is major in K dim, only 16 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput. + If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized + tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the macro + OPTIMIZE_WEIGHT_LAYOUT is set to 1. + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. 
@@ -104,9 +112,12 @@ #include "helper.h" #include "unfused_weight_dequantize.hpp" #include "packed_scale.hpp" +#include "reorder_utils.hpp" using namespace cute; +#define OPTIMIZE_WEIGHT_LAYOUT 1 + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -130,6 +141,17 @@ constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // M using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; + +#if OPTIMIZE_WEIGHT_LAYOUT +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory. +// It specifies the reordering within a single warp's fragment +using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); +#endif + using ElementScale = MmaType; using ElementZero = ElementScale; // only for verify using LayoutScale = cutlass::layout::RowMajor; @@ -172,7 +194,11 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui // The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. 
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - cute::tuple >, LayoutB_Transpose, AlignmentB, +#if OPTIMIZE_WEIGHT_LAYOUT + cute::tuple>, LayoutB_Reordered, AlignmentB, +#else + cute::tuple>, LayoutB_Transpose, AlignmentB, +#endif ElementA, LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, ClusterShape, @@ -190,8 +216,6 @@ using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; -using StrideA = cutlass::detail::TagToStrideA_t; -using StrideB = cutlass::detail::TagToStrideB_t; using StrideC = typename GemmKernelScaleOnly::StrideC; using StrideD = typename GemmKernelScaleOnly::StrideD; @@ -211,6 +235,10 @@ StrideD stride_D; StrideD_ref stride_D_ref; uint64_t seed; +#if OPTIMIZE_WEIGHT_LAYOUT +LayoutB_Reordered layout_B_reordered; +#endif + using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; using StrideS_ref = cutlass::detail::TagToStrideB_t; StrideS stride_S; @@ -399,7 +427,7 @@ bool unify_quant_encoding( d = out; } - cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2); + cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack); return true; } @@ -461,10 +489,10 @@ bool initialize_zero( /// Initialize operands to be used in the GEMM and reference GEMM void initialize(Options const& options) { - auto shape_b = cute::make_shape(options.n, options.k, options.l); + auto shape_B = cute::make_shape(options.n, options.k, options.l); int const scale_k = (options.k + options.g - 1) / options.g; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); - stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); // Reverse stride here due to swap and transpose stride_C = 
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); @@ -472,6 +500,8 @@ void initialize(Options const& options) { stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + auto layout_B = make_layout(shape_B, stride_B); + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); @@ -496,14 +526,22 @@ void initialize(Options const& options) { initialize_packed_scale(block_scale, block_scale_packed); initialize_zero(block_zero, options); - auto layout_B = make_layout(shape_b, stride_B); - auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); + + #if OPTIMIZE_WEIGHT_LAYOUT + // Repeat the reorder layout atom to tile the whole tensor shape + layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); + reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered); + + print("Quantized tensor layout: "); + print(layout_B_reordered); + print("\n"); +#endif } /// Populates a Gemm::Arguments structure from the given commandline options @@ -515,7 +553,11 @@ Args args_from_options(Options const& options) return Args { cutlass::gemm::GemmUniversalMode::kGemm, 
{options.n, options.m, options.k, options.l}, - {block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g}, +#if OPTIMIZE_WEIGHT_LAYOUT + {block_B_modified.get(), layout_B_reordered, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g}, +#else + {block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g}, +#endif {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; } @@ -581,6 +623,7 @@ bool verify(Options const& options) { ElementD const epsilon(1e-2f); ElementD const non_zero_floor(1e-4f); bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + return passed; } diff --git a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt index a9753ed1..23dca4f3 100644 --- a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt +++ b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt @@ -68,3 +68,14 @@ cutlass_example_add_executable( TEST_SCALE_RESIDUE # TEST_ALPHA_BETA ) + + cutlass_example_add_executable( + 55_hopper_int4_bf16_gemm + 55_hopper_int4_bf16_gemm.cu + TEST_COMMAND_OPTIONS + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL + TEST_SCALE_GROUP + TEST_SCALE_RESIDUE + # TEST_ALPHA_BETA + ) diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp index 294d4261..7d732dcd 100644 --- a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -31,7 +31,6 @@ #pragma once -#include #include #include "cutlass/float8.h" diff --git a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp new file mode 100644 index 00000000..2be42551 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp @@ -0,0 +1,122 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" + +#include "cutlass/util/device_memory.h" + +// Given a type of MMA instruction, compute a memory reordering atom that places all values +// owned by each thread in contiguous memory locations. This improves smem load vectorization, +// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order +// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. +template +auto compute_memory_reordering_atom() +{ + using namespace cute; + + // 1. Choose an MMA atom to access TV layout and MN shape + // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary + using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); + using MmaTraits = MMA_Traits; + auto shape_MK = select<0,2>(typename MmaTraits::Shape_MNK{}); + auto tv_layout_mma = typename MmaTraits::ALayout{}; + + // 2. Create a single warp's TV layout from that of the whole MMA + // Note: this assumes A is partitioned between warps along M mode + auto tile_TV_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); + auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tile_TV_warp)); + + // 3. Invert warp's TV layout to get MK layout (m,k -> thr,val) + auto shape_MK_warp = shape_div(shape_MK, size(typename MmaTraits::ThrID{}) / Int<32>{}); + auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(shape_MK_warp); + + // 4. 
Compose with a contiguous layout of values in each thread (required for smem vectorization) + auto tv_to_offset = make_ordered_layout(shape(tv_layout_mma_warp), Step<_1,_0>{}); + auto layout_atom = composition(tv_to_offset, mk_layout_mma_warp); + + return layout_atom; +} + +template +__global__ void reorder_tensor_kernel( + cute::Tensor src, + cute::Tensor dst) +{ + auto i = blockIdx.x; + auto k = blockIdx.y; + for (int j = threadIdx.x; j < cute::size<1>(src); j += blockDim.x) { + dst(i,j,k) = src(i,j,k); + } +} + +template +void reorder_tensor( + cute::Tensor t_src, + cute::Tensor t_dst) +{ + using T = typename EngineDst::value_type; + static_assert(cute::is_same_v, T>, "Type mismatch"); + using V = cute::uint_bit_t)>; + + cute::Tensor v_src = cute::recast(t_src); + cute::Tensor v_dst = cute::recast(t_dst); + + int threads = 256; + dim3 blocks{unsigned(cute::size<0>(v_src)), unsigned(cute::size<2>(v_src)), 1u}; + + reorder_tensor_kernel<<>>(v_src, v_dst); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +// Out-of-place version (separate source and destination buffers) +template +void reorder_tensor( + T const* src, + LayoutSrc const& layout_src, + T * dst, + LayoutDst const& layout_dst) +{ + reorder_tensor(make_tensor(src, layout_src), + make_tensor(dst, layout_dst)); +} + +// In-place version +template +void reorder_tensor( + T * data, + LayoutSrc const& layout_src, + LayoutDst const& layout_dst) +{ + cutlass::DeviceAllocation temp(cute::size(layout_src)); + reorder_tensor(data, layout_src, temp.get(), layout_dst); + cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(cute::size(layout_src))); +} \ No newline at end of file diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index c87ce682..5a70f590 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -340,7 +340,7 @@ auto all_of(T const& t, F&& f) { if constexpr (is_tuple::value) { - return detail::apply(t, [&] (auto const&...
a) { return (true_type{} && ... && f(a)); }, tuple_seq{}); + return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq{}); } else { return f(t); } diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 216ba402..cbed61f6 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -198,13 +198,22 @@ is_major(Stride = {}) { return cute::is_constant<1, decltype(cute::front(cute::get(cute::remove_pointer_t{})))>::value; } +template +constexpr bool +is_major(cute::Layout = {}) { + return is_major(Stride{}); +} + // Note : This method can be used for deducing the Layout Tag of A, C, D Matrices template constexpr auto stride_to_layout_tag_A() { using InternalStrideA = cute::remove_pointer_t; - if constexpr (is_major<0, StrideA>()) { // M major + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_A(); + } + else if constexpr (is_major<0, StrideA>()) { // M major return layout::ColumnMajor{}; } // Specialize for sparse layout @@ -224,7 +233,11 @@ template constexpr auto stride_to_layout_tag_B() { - if constexpr (is_major<0, StrideB>()) { // N major + using InternalStrideB = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_B(); + } + else if constexpr (is_major<0, StrideB>()) { // N major return layout::RowMajor{}; } else { // K major @@ -238,7 +251,11 @@ template constexpr auto stride_to_layout_tag_C() { - if constexpr (is_major<0, StrideC>()) { // M major + using InternalStrideC = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_C(); + } + else if constexpr (is_major<0, StrideC>()) { // M major return layout::ColumnMajor{}; } else { // N major @@ -349,28 +366,25 @@ get_output_alignment_bits() { return 128; } - -// Return the shape that is associated with stride-1 mode, or 1 if not found -template -CUTLASS_HOST_DEVICE constexpr -auto 
-get_contiguous_shape(Shape const & shape, Stride const & stride) { - using namespace cute; - auto idx = find_if(append(flatten(stride), _1{}), [](auto s){ return is_constant<1,decltype(s)>{}; }); - return get(append(flatten(shape), _1{})); -} - -// Check if tensor shape satisfies a given major alignment +// Check if tensor layout satisfies a given major alignment template CUTLASS_HOST_DEVICE constexpr bool -check_alignment(Shape const & shape, Stride const & stride) { - return is_major<0>(stride) - ? get_contiguous_shape(cute::get<0>(shape), cute::get<0>(stride)) % Alignment == 0 - : get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0; +check_alignment(cute::Layout const& layout) { + // Condition: shape must divide by Alignment without rounding + bool shape_check = cute::size(layout.shape()) == Alignment * cute::size(cute::upcast(layout)); + // Condition: every dynamic stride must be a multiple of Alignment + bool stride_check = cute::all_of(cute::flatten(layout.stride()), [](auto s){ return cute::is_static::value || (s % Alignment == 0); }); + return shape_check && stride_check; } -// Check if tensor shape satisfies a given major alignment +// Check if tensor layout satisfies a given major alignment +template +CUTLASS_HOST_DEVICE constexpr +bool +check_alignment(Shape const& shape, Stride const& stride) { + return check_alignment(cute::make_layout(shape, stride)); +} template CUTLASS_HOST_DEVICE constexpr diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index b2fa4e35..b96c4aea 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -327,13 +327,23 @@ public: if constexpr (is_destination_supported) { constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); constexpr int 
min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - implementable = cutlass::detail::check_alignment(shape, StrideD{}); + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); + } + else { + implementable = cutlass::detail::check_alignment(shape, StrideD{}); + } } if constexpr (not cute::is_void_v) { constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } } if (!implementable) { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index a4cc7686..8657aad2 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -409,8 +409,18 @@ public: static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; - using GmemLayoutATag = GmemLayoutATag_; - using GmemLayoutBTag = GmemLayoutBTag_; + template + static auto get_stride(T const& t) { + if constexpr (not cute::is_layout::value) { + return t; + } + else { + return cute::stride(t); + } + } + + using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{})); + using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{})); using ElementPairA = cute::conditional_t, ElementPairA_>; using ElementPairB = cute::conditional_t, ElementPairB_>; @@ -464,8 +474,8 @@ public: using DispatchPolicy = 
MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; // We pack the scale data with the operand that will be optionally scaled and converted before MMA. - using StrideA = TagToStrideA_t; - using StrideB = TagToStrideB_t; + using StrideA = cute::conditional_t::value, GmemLayoutATag_, TagToStrideA_t>; + using StrideB = cute::conditional_t::value, GmemLayoutBTag_, TagToStrideB_t>; using CollectiveOp = CollectiveMma< DispatchPolicy, diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 8c98d15c..0871d26c 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -182,6 +182,7 @@ public: using InternalSmemLayoutAtomB = cute::conditional_t; using InternalSmemCopyAtomA = cute::conditional_t; using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. static constexpr bool ConvertF32toTF32A = cute::is_same_v; @@ -228,14 +229,25 @@ public: static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); // Tile along modes in a way that maximizes the TMA box size. 
- using SmemLayoutA = decltype(tile_to_shape( - InternalSmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - using SmemLayoutB = decltype(tile_to_shape( - InternalSmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + template + static constexpr + CUTLASS_HOST_DEVICE + auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) { + if constexpr (not cute::is_layout::value) { + return tile_to_shape( + layout_atom, + append(tile_shape, Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}); + } + else { + auto gmem_tile = composition(stride, tile_shape); + return make_layout_like(append(gmem_tile, make_layout(Int{}, 0))); + } + } + + using SmemLayoutA = decltype(get_smem_layout(InternalSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalStrideA{})); + using SmemLayoutB = decltype(get_smem_layout(InternalSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalStrideB{})); // It is assumed that the scales and zero-points share the same smem layout using SmemLayoutScale = decltype(tile_to_shape( @@ -381,6 +393,18 @@ public: uint32_t mma_promotion_interval = 4; }; + template + static constexpr + CUTLASS_HOST_DEVICE + auto get_gmem_layout(Shape const& shape, Stride const& stride) { + if constexpr (not cute::is_layout::value) { + return make_layout(shape, stride); + } + else { + return stride; + } + } + // Device side kernel params struct Params { private: @@ -394,10 +418,14 @@ public: TransformB_>; public: + // Assumption: StrideA is congruent with Problem_MK + using LayoutA = decltype(get_gmem_layout(repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{})); + using LayoutB = 
decltype(get_gmem_layout(repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{})); + using TMA_A = decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, - make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), LayoutA{}), SmemLayoutA{}(_,_,cute::Int<0>{}), TileShape{}, ClusterShape{})); // mcast along N mode for this M load, if any @@ -419,7 +447,7 @@ public: // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, - make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), LayoutB{}), SmemLayoutB{}(_,_,cute::Int<0>{}), TileShape{}, ClusterShape{})); // mcast along M mode for this N load, if any @@ -431,6 +459,8 @@ public: int group_size; uint32_t tma_transaction_bytes = TmaTransactionBytes; int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + InternalStrideA dA; + InternalStrideB dB; }; // @@ -469,8 +499,8 @@ public: dB = args.dA; } - Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA)); - Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB)); + Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), get_gmem_layout(make_shape(M,K,L), dA)); + Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), get_gmem_layout(make_shape(N,K,L), dB)); typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( GmemTiledCopyA{}, tensor_a, @@ -490,7 +520,7 @@ public: uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1 }; + return { tma_load_a, tma_load_b, tma_load_scale, 
tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB }; } else if constexpr (ModeHasScales) { auto scale_k = (K + args.group_size - 1) / args.group_size; @@ -505,7 +535,7 @@ public: _1{}); // mcast along N mode for this M load, if any if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB }; } else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS)); @@ -515,7 +545,7 @@ public: SmemLayoutScale{}(_,_,cute::Int<0>{}), ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB }; } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); } @@ -533,33 +563,37 @@ public: constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - - bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && 
cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + bool check_aligned_A = cutlass::detail::check_alignment(get_gmem_layout(cute::make_shape(M,K,L), args.dA)); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + bool check_aligned_B = cutlass::detail::check_alignment(get_gmem_layout(cute::make_shape(N,K,L), args.dB)); + + bool check_aligned_S = true; + bool check_aligned_Z = true; + bool check_mode_args = true; if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - implementable = implementable && (args.ptr_S == nullptr); - implementable = implementable && (args.ptr_Z == nullptr); + check_mode_args = check_mode_args && (args.ptr_S == nullptr); + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); } else if constexpr (ModeHasScales) { const int scale_mn = SwapAB ? N : M; const int scale_k = (K + args.group_size - 1) / args.group_size; constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); - implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); - implementable = implementable && args.group_size != 0; - implementable = implementable && (args.ptr_S != nullptr); + check_aligned_S = cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), args.dS); + check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + check_mode_args = check_mode_args && args.group_size != 0; + check_mode_args = check_mode_args && (args.ptr_S != nullptr); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - implementable = implementable && (args.ptr_Z == nullptr); + check_mode_args = check_mode_args 
&& (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); - implementable = implementable && (args.ptr_Z != nullptr); + check_aligned_Z = cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), args.dS); + check_mode_args = check_mode_args && (args.ptr_Z != nullptr); } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); @@ -569,10 +603,23 @@ static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); } - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + if (!check_mode_args) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n"); } - return implementable; + if (!check_aligned_A) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A does not meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_B) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B does not meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_S) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) does not meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_Z) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) does not meet the minimum alignment requirements for TMA.\n"); + } + + return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z; } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; @@ -618,8 +665,8 @@ // TMA requires special handling of strides to deal with coord codomain mapping // Represent the full tensors -- get these from TMA - Tensor mA_mkl = 
mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(get_gmem_layout(make_shape(M,K,L), mainloop_params.dA))); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(get_gmem_layout(make_shape(N,K,L), mainloop_params.dB))); // (n,k,l) // Make tiled views, defer the slice Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) @@ -680,8 +727,6 @@ public: static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); } - int lane_predicate = cute::elect_one_sync(); - Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) @@ -748,8 +793,10 @@ public: BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); int write_stage = smem_pipe_write.index(); - if (cute::elect_one_sync()) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - if (cute::elect_one_sync()) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // Nothing extra to do. 
@@ -920,6 +967,12 @@ public: // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); @@ -927,11 +980,6 @@ public: if (k_block < K_BLOCK_MAX - 1) { transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } - warpgroup_arrive(); - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - warpgroup_commit_batch(); } --k_tile_count; diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index 33d3943f..ac288e3e 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -70,7 +70,7 @@ using cutlass::detail::StrideToLayoutTagC_t; template constexpr bool is_major(Stride = {}) { - return ::cutlass::detail::is_major(); + return ::cutlass::detail::is_major(Stride{}); } template