// cutlass/examples/cute/tutorial/blackwell/01_mma_sm100.cu

/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// CuTe Tutorial for SM100 Programming
// This tutorial series demonstrates CuTe Blackwell capabilities that are frequently used
// throughout CUTLASS. The goal is to familiarize developers with CuTe SM100 interfaces.
//
// The tutorial series is split into five stages:
// * 01_mma_sm100.cu: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction.
// * 02_mma_tma_sm100.cu: Simple Blackwell SM100 GEMM using tcgen05.mma and TMA instructions.
// * 03_mma_tma_multicast_sm100.cu: Blackwell SM100 GEMM using tcgen05.mma and Multicast TMA.
// * 04_mma_tma_2sm_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma and 2SM Multicast TMA.
// * 05_mma_tma_epi_sm100.cu: Blackwell SM100 GEMM with 2SM tcgen05.mma, 2SM TMA mainloop, and TMA epilogue.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <cstdio>
// Use Thrust to handle host/device allocations
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
// Cutlass includes
#include <cutlass/half.h> // F16 data type
#include <cutlass/util/print_error.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/cluster_sm90.hpp> // CuTe functions for querying details of the launched cluster
#include <cute/numeric/integral_constant.hpp> // Compile-time integral constants such as _1, _256, etc.
#include <cute/algorithm/cooperative_copy.hpp> // Auto vectorized copy operation
#include <cute/arch/tmem_allocator_sm100.hpp> // TMEM allocator for SM100
// Tutorial helpers
#include "example_utils.hpp"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Tutorial 01: Simple Blackwell SM100 GEMM using a tcgen05.mma instruction
//
///////////////////////////////////////////////////////////////////////////////////////////////////
// The goal of this tutorial is to show the CuTe interface for tcgen05.mma and tcgen05.ld operations.
// We will implement a GEMM operation: D (F32) = beta * C (F32) + alpha * A (F16) * B (F16) where:
// - Matrix A is MxK, K-major (BLAS transpose T, row-major)
// - Matrix B is NxK, K-major (BLAS transpose N, column-major)
// - Matrices C and D are MxN, N-major (BLAS row-major)
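// For example, with the default problem size (Gemm_M, Gemm_N, Gemm_K) = (512, 1024, 256) used by this tutorial,
// A is laid out as (512,256):(256,_1): the M index strides by Gemm_K while the K index is contiguous, i.e., K-major.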
//
// This GEMM kernel performs the following steps:
// 1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) for one MmaTile
// using auto-vectorizing copy operations.
// 2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
// 3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
// 4. Read C matrix from global memory (GMEM) to register (RMEM).
// 5. Apply alpha and beta scaling to the MMA accumulator and C matrix.
// 6. Store D matrix from registers (RMEM) to global memory (GMEM).
//
// SM100 tcgen05.mma instructions operate as follows:
// - Read matrix A from SMEM or TMEM
// - Read matrix B from SMEM
// - Write accumulator to TMEM
// The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
//
// The tcgen05.mma instruction requires an Instruction Descriptor that encodes A, B, and Accumulator types
// and the MMA's M and N dimensions.
// The A and B matrices that are read from SMEM need to be provided to MMA instructions as SMEM Descriptors.
// These are the A and B fragments of the tcgen05.mma in CuTe terminology.
// CuTe constructs these descriptors transparently within the MMA instruction and its fragments, as shown in this tutorial.
//
// The MMA details:
// We use the tcgen05.mma.f16 instruction (F16xF16 = F32) that performs a 128x256x16 MMA
// operation. F32 accumulator type is chosen since both C and D matrices use F32.
// This example uses F16xF16 = F32 MMA where:
// TypeA = cutlass::half_t; // MMA A Data Type
// TypeB = cutlass::half_t; // MMA B Data Type
// TypeC = float; // MMA C Data Type
// TypeD = float; // MMA D Data Type
// TypeAccumulator = float; // Both TypeC and TypeD are float, so we use float accumulator type
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// The shared memory buffers for A and B matrices.
template <class TypeA, // Tensor A data type
class TypeB, // Tensor B data type
class ASmemLayout, // (MmaA, NumMma_M, NumMma_K, ...)
class BSmemLayout> // (MmaB, NumMma_N, NumMma_K, ...)
struct SharedStorage
{
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM
alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation
CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); }
};
// The device kernel
template <class SharedStorage,
class ATensor, class BTensor, class CTensor, class DTensor,
class MmaTiler_MNK, class TiledMMA, class ClusterShape_MNK,
class Alpha, class Beta>
__global__ static
void
gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
BTensor mB, // (Gemm_N, Gemm_K)
CTensor mC, // (Gemm_M, Gemm_N)
DTensor mD, // (Gemm_M, Gemm_N)
MmaTiler_MNK mma_tiler, // <MmaTile_M, MmaTile_N, MmaTile_K>
TiledMMA tiled_mma, // < Mma_M, Mma_N, Mma_K>
ClusterShape_MNK cluster_shape, // (ClusterM, ClusterN, ClusterK)
Alpha alpha, Beta beta)
{
// Step 1: The Prologue.
// The CTA layout within the Cluster: (V,M,N,K) -> CTA idx
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename TiledMMA::AtomThrID{}));
// Construct the MMA grid coordinate from the CTA grid coordinate
auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), // Peer CTA coordinate
blockIdx.x / size<0>(cluster_layout_vmnk), // MMA-M coordinate
blockIdx.y, // MMA-N coordinate
_); // MMA-K coordinate
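// Worked example: this tutorial launches a (1,1,1) cluster and uses a 1SM MMA whose AtomThrID is _1:_0, so
// cluster_layout_vmnk is ((_1),_1,_1,_1) and mma_coord_vmnk reduces to (0, blockIdx.x, blockIdx.y, _).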
// Partition the GMEM tensors with the mma_tiler and mma_coord to get the slices processed
// by this mma tile.
// CuTe provides local_tile partitioning function. local_tile accepts 4 parameters:
// * Tensor to partition
// * Tiler to use for partitioning
// * Coordinate to use for slicing the partitioned tensor
// * Projection to ignore unwanted modes of the Tiler and Coordinate
auto mma_coord = select<1,2,3>(mma_coord_vmnk);
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X,_1>{}); // (MmaTile_M, MmaTile_K, Tiles_K)
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step< X,_1,_1>{}); // (MmaTile_N, MmaTile_K, Tiles_K)
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1,_1, X>{}); // (MmaTile_M, MmaTile_N)
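// With the projections applied, the calls above dice away the unwanted modes of the tiler and coordinate first.
// A minimal sketch of the equivalent unprojected call for A (illustrative only, assuming CuTe's dice semantics
// for Step<...> projections; not needed by this kernel):
//   Tensor gA_alt = local_tile(mA, select<0,2>(mma_tiler), select<0,2>(mma_coord)); // same result as gA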
if (thread0()) {
print("mA:\t"); print(mA); print("\n"); // mA: gmem_ptr[16b](GMEM_ADDR_A) o (512,256):(256,_1)
print("mB:\t"); print(mB); print("\n"); // mB: gmem_ptr[16b](GMEM_ADDR_B) o (1024,256):(256,_1)
print("mC:\t"); print(mC); print("\n"); // mC: gmem_ptr[32b](GMEM_ADDR_C) o (512,1024):(1024,_1)
print("mD:\t"); print(mD); print("\n"); // mD: gmem_ptr[32b](GMEM_ADDR_D) o (512,1024):(1024,_1)
print("gA:\t"); print(gA); print("\n"); // gA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile) o (_128,_64,4):(256,_1,_64)
print("gB:\t"); print(gB); print("\n"); // gB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile) o (_256,_64,4):(_1,256,16384)
print("gC:\t"); print(gC); print("\n"); // gC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile) o (_128,_256):(256,_1)
print("gD:\t"); print(gD); print("\n"); // gD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile) o (_128,_256):(256,_1)
} __syncthreads();
// The SMEM tensors
// Allocate SMEM
extern __shared__ char shared_memory[];
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// Represent the SMEM buffers for A and B
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K)
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K)
//
// Mma partitioning for A and B
//
// Note: Partitioned tensors use tXgY naming convention:
// tXgY -> The partitioning pattern tX applied to tensor gY
auto mma_v = get<0>(mma_coord_vmnk);
ThrMMA cta_mma = tiled_mma.get_slice(mma_v); // Use Peer CTA coordinate
Tensor tCgA = cta_mma.partition_A(gA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCgB = cta_mma.partition_B(gB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)
Tensor tCgC = cta_mma.partition_C(gC); // (MmaC, NumMma_M, NumMma_N)
Tensor tCgD = cta_mma.partition_C(gD); // (MmaC, NumMma_M, NumMma_N)
if (thread0()) {
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: gmem_ptr[16b](GMEM_ADDR_A + offset_for_mma_tile + offset_for_mma) o ((_128,_16),_1,_4,4):((256,_1),_0,_16,_64)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: gmem_ptr[16b](GMEM_ADDR_B + offset_for_mma_tile + offset_for_mma) o ((_256,_16),_1,_4,4):((_1,256),_0,4096,16384)
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
} __syncthreads();
// MMA Fragment Allocation
// We allocate "fragments" which are SMEM descriptors that serve as inputs to cute::gemm operations.
// For tcgen05.mma operations:
// - Matrices A and B are sourced from SMEM
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
// - The first mode of each descriptor represents the SMEM for a single MMA operation
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K)
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K)
// TMEM Allocation
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
// ThrMma's make_fragment_C() creates a TMEM tensor with the appropriate layout for the accumulator.
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); // (MmaC, NumMma_M, NumMma_N)
uint32_t elect_one_thr = cute::elect_one_sync();
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
using TmemAllocator = cute::TMEM::Allocator1Sm;
TmemAllocator tmem_allocator{};
if (elect_one_warp) {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
}
__syncthreads(); // Wait for all threads until warp0 allocates TMEM
tCtAcc.data() = shared_storage.tmem_base_ptr;
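// Note (an assumption about SM100 TMEM addressing, stated here to help read the print below): a TMEM address is a
// 32-bit value whose upper 16 bits select the lane (row) and lower 16 bits select the column, which is why the
// M mode of tCtAcc prints with stride _65536 (= 1 << 16).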
if (thread0()) {
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
} __syncthreads();
// Barrier Initialization
// Barriers in SMEM are initialized by a single thread.
if (elect_one_warp && elect_one_thr) {
cute::initialize_barrier(shared_storage.mma_barrier, /* num_ctas */ 1);
}
int mma_barrier_phase_bit = 0; // Each barrier has an associated phase_bit.
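// The barrier flips between phase 0 and phase 1 each time its expected arrival count is reached. We track the
// parity we will wait on in mma_barrier_phase_bit and flip it after every wait (see the mainloop below).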
__syncthreads(); // Make sure all threads observe barrier initialization.
// Step 2: The Mainloop.
// Set mma accumulate option to zero so that the first MMA instruction will clear the TMEM accumulator.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
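// In other words, the first cute::gemm call below computes Acc = A * B; after accumulate_ is switched to
// UMMA::ScaleOut::One inside the k_block loop, subsequent calls compute Acc += A * B.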
// Execute a MmaTile_M x MmaTile_N x GEMM_K GEMM
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile)
{
// Step 2a: Load A and B tiles
// Using auto-vectorized copy operation:
// - Utilizes 128 threads for parallel data transfer
// - Copy operations are distributed efficiently across all threads
// - CuTe can automatically determine optimal vector width
cooperative_copy<128>(threadIdx.x, tCgA(_,_,_,k_tile), tCsA); // Load MmaTile_M x MmaTile_K A tile
cooperative_copy<128>(threadIdx.x, tCgB(_,_,_,k_tile), tCsB); // Load MmaTile_N x MmaTile_K B tile
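// cooperative_copy can also be given an explicit upper bound on the vector width. A minimal sketch, assuming the
// optional MaxVecBits template parameter of cute::cooperative_copy (not required here, since CuTe picks the width):
//   cooperative_copy<128, 128>(threadIdx.x, tCgA(_,_,_,k_tile), tCsA); // cap vectorization at 128 bits per thread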
// Step 2b: Execute the MMAs for this tile
// Wait for loads to SMEM to complete with __syncthreads()
__syncthreads();
// tcgen05.mma instructions require single-thread execution:
// - Only one warp performs the MMA-related loop operations
// - CuTe operations internally manage the single-thread execution of tcgen05.mma and tcgen05.cp
// - No explicit elect_one_sync region is needed from the user
if (elect_one_warp) {
// Execute a MmaTile_M x MmaTile_N x MmaTile_K GEMM
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtAcc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
// Have the MMAs signal the mma_barrier on completion; only then can the A and B SMEM be reused.
cutlass::arch::umma_arrive(&shared_storage.mma_barrier);
}
// Wait for the MMAs to complete before overwriting the A and B SMEM.
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
mma_barrier_phase_bit ^= 1;
}
// Step 3: The Epilogue.
// Create the tiled copy operation for the accumulator (TMEM -> RMEM)
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
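// Reading the atom name (an assumption about the SM100_TMEM_LOAD naming convention): 32dp32b1x denotes a load of
// 32 data paths (TMEM lanes) x 32-bit elements, repeated 1x per instruction.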
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
Tensor tDgC = thr_t2r_copy.partition_D(tCgC); // (CpyD, NumCpy_M, NumCpy_N)
Tensor tDrC = make_fragment_like(tDgC); // (CpyD, NumCpy_M, NumCpy_N)
// Load C tensor GMEM -> RMEM
copy(tDgC, tDrC);
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); // (CpyS, NumCpy_M, NumCpy_N)
Tensor tDgD = thr_t2r_copy.partition_D(tCgD); // (CpyD, NumCpy_M, NumCpy_N)
using AccType = typename decltype(tCtAcc)::value_type;
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD)); // (CpyD, NumCpy_M, NumCpy_N)
// Load TMEM -> RMEM
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
// AXPBY RMEM -> RMEM: tDrC = alpha * tDrAcc + beta * tDrC
axpby(alpha, tDrAcc, beta, tDrC);
// Store RMEM -> GMEM
copy(tDrC, tDgD);
__syncthreads();
// Release the right to allocate before deallocations so that the next CTA can rasterize
// Then deallocate TMEM
if (elect_one_warp) {
tmem_allocator.release_allocation_lock();
tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
template <class TypeA, class LayoutA,
class TypeB, class LayoutB,
class TypeC, class LayoutC,
class TypeD, class LayoutD,
class Alpha, class Beta>
void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,
TypeB const* device_ptr_B, LayoutB layout_B,
TypeC const* device_ptr_C, LayoutC layout_C,
TypeD * device_ptr_D, LayoutD layout_D,
Alpha const alpha, Beta const beta)
{
assert(shape<0>(layout_A) == shape<0>(layout_C)); // Gemm_M
assert(shape<0>(layout_A) == shape<0>(layout_D)); // Gemm_M
assert(shape<0>(layout_B) == shape<1>(layout_C)); // Gemm_N
assert(shape<0>(layout_B) == shape<1>(layout_D)); // Gemm_N
assert(shape<1>(layout_A) == shape<1>(layout_B)); // Gemm_K
// Represent the full tensors in global memory
Tensor mA = make_tensor(make_gmem_ptr(device_ptr_A), layout_A); // (Gemm_M, Gemm_K)
Tensor mB = make_tensor(make_gmem_ptr(device_ptr_B), layout_B); // (Gemm_N, Gemm_K)
Tensor mC = make_tensor(make_gmem_ptr(device_ptr_C), layout_C); // (Gemm_M, Gemm_N)
Tensor mD = make_tensor(make_gmem_ptr(device_ptr_D), layout_D); // (Gemm_M, Gemm_N)
// Get M, N, K dimensions of the GEMM we are running
auto Gemm_M = shape<0>(layout_A);
auto Gemm_N = shape<0>(layout_B);
auto Gemm_K = shape<1>(layout_A);
std::cout << "Running for problem shape (MxNxK): " << Gemm_M << "x" << Gemm_N << "x" << Gemm_K << std::endl;
////////////////////////////////////////////////////////////
//
// Initialize the GEMM kernel parameters
//
////////////////////////////////////////////////////////////
// Create TiledMma. make_tiled_mma takes the target instructions and an (optional) instruction layout as parameters to create a
// larger TiledMma from the given mma instruction.
// See cute/arch/mma_sm100_umma.hpp for all tcgen05.mma instructions
TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, // Mma's A, B, and Accumulator types
128, 256, // Mma M and N dimensions
UMMA::Major::K, UMMA::Major::K>{}); // A and B layouts
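// Other tcgen05.mma shapes follow the same pattern. A minimal sketch with a smaller N (illustrative only; this
// tutorial keeps the 128x256 instruction above):
//   TiledMMA tiled_mma_128x128 = make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC,
//                                                                    128, 128,
//                                                                    UMMA::Major::K, UMMA::Major::K>{});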
// We can also print and inspect the tiled_mma
print(tiled_mma);
// TiledMMA
// ThrLayoutVMNK: (_1,_1,_1,_1):(_0,_0,_0,_0)
// PermutationMNK: (_,_,_)
// MMA_Atom
// ThrID: _1:_0
// Shape_MNK: (_128,_256,_16) // MmaM, MmaN, MmaK instruction size
// LayoutA_TV: (_1,(_128,_16)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for A matrix
// LayoutB_TV: (_1,(_256,_16)):(_0,(_1,_256)) // TV -> MmaCoordinate mapping for B matrix
// LayoutC_TV: (_1,(_128,_256)):(_0,(_1,_128)) // TV -> MmaCoordinate mapping for C matrix
// Define MMA tiler sizes (static)
auto bM = tile_size<0>(tiled_mma); // MMA Tile M. We'll use 1 MMA per MMA Tile M.
auto bN = tile_size<1>(tiled_mma); // MMA Tile N. We'll use 1 MMA per MMA Tile N.
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // MMA Tile K. We'll use 4 MMAs per MMA Tile K. For 16b types, tcgen05.mma has K16.
auto mma_tiler = make_shape(bM, bN, bK); // (MmaTile_M, MmaTile_N, MmaTile_K)
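// For this example's 128x256x16 instruction: bM = _128, bN = _256, bK = _16 * _4 = _64,
// so mma_tiler = (_128, _256, _64) and each MMA tile covers a 128x256 output block over 64 K elements.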
// In SM90, the MMAs are CTA-local and perform thread-level partitioning.
// In SM100, the MMAs are Cluster-local and perform CTA-level partitioning.
// Thus, SM90 uses a cta_tiler to extract portions of the Problem for the CTA
// and SM100 uses a mma_tiler to extract portions of the Problem for the MMA.
// The MMA's partitioning then yields the CTA-local work.
if (not evenly_divides(shape(mma_tiler), tile_shape(tiled_mma))) {
std::cerr << "The MMA Shape should evenly divide the MMA Tiler." << std::endl;
return;
}
if (not evenly_divides(make_shape(Gemm_M, Gemm_N, Gemm_K), mma_tiler)) {
std::cerr << "OOB accesses are not supported. MmaTiler_MNK should evenly divide ProblemShape_MNK." << std::endl;
return;
}
//
// Determine the SMEM layouts:
//
// * SMEM layouts for A and B must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.
// * CuTe provides partition_shape_[A|B] functions to determine the post-partitioned shape.
// These functions take the TiledMma and the MMA Tile Shape as inputs and return a shape that is at least rank-3,
// where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of
// times the MMA instruction is repeated in the M/N mode and the K mode of the MMA tile, respectively.
// * Note that SMEM layouts are needed to determine SMEM allocation for kernel launch.
// Pre-partitioned Tile Shape (MmaTile_M, MmaTile_K) to post-partitioned (MmaA, NumMma_M, NumMma_K)
auto mma_shape_A = partition_shape_A(tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
// Pre-partitioned Tile Shape (MmaTile_N, MmaTile_K) to post-partitioned (MmaB, NumMma_N, NumMma_K)
auto mma_shape_B = partition_shape_B(tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
// Print and inspect mma_shape_A, and mma_shape_B for this example.
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
// A and B tensors are swizzled in SMEM to improve MMA performance.
// * However, expressing swizzled layouts by hand is error-prone.
// * CuTe provides UMMA::tile_to_mma_shape for SM100 to tile a swizzled layout atom to the post-partitioned Mma Shape.
auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
// Print and inspect sA_layout and sB_layout for this example.
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
// The cluster shape and layout
auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});
Layout cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape),
make_tile(typename decltype(tiled_mma)::AtomThrID{}));
////////////////////////////////////////////////////////////
//
// Launch GEMM kernel
//
////////////////////////////////////////////////////////////
dim3 dimBlock(128);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x),
round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y));
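// Worked example: for the default 512x1024x256 problem with a (1,1,1) cluster,
// dimGrid = (ceil_div(512,128), ceil_div(1024,256)) = (4, 4, 1), i.e., one CTA per 128x256 MMA tile.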
int smemBytes = sizeof(SMEMStorage);
auto* kernel_ptr = &gemm_device<SMEMStorage,
decltype(mA), decltype(mB), decltype(mC), decltype(mD),
decltype(mma_tiler), decltype(tiled_mma), decltype(cluster_shape),
Alpha, Beta>;
// Set kernel attributes (set SMEM)
CUTE_CHECK_ERROR(cudaFuncSetAttribute(kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smemBytes));
printf("Grid launched: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
printf("Cluster launched: %d, %d, %d\n", dimCluster.x, dimCluster.y, dimCluster.z);
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(params, (void const*) kernel_ptr,
mA, mB, mC, mD,
mma_tiler, tiled_mma, cluster_shape,
alpha, beta);
CUTE_CHECK_LAST();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Error: Failed at kernel Launch" << std::endl;
}
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int main(int argc, char** argv)
{
cudaDeviceProp props;
int current_device_id;
cudaGetDevice(&current_device_id);
cudaError_t error = cudaGetDeviceProperties(&props, current_device_id);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major != 10 || props.minor > 1) {
std::cerr << "This example requires NVIDIA's Blackwell Architecture GPU with compute capability 100a." << std::endl;
std::cerr << " Found " << props.major << "." << props.minor << std::endl;
return -1;
}
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
int Gemm_M = 512;
if (argc >= 2)
sscanf(argv[1], "%d", &Gemm_M);
int Gemm_N = 1024;
if (argc >= 3)
sscanf(argv[2], "%d", &Gemm_N);
int Gemm_K = 256;
if (argc >= 4)
sscanf(argv[3], "%d", &Gemm_K);
////////////////////////////////////////////////////////////
//
// Create A, B, C, and D tensors
//
////////////////////////////////////////////////////////////
// Define the data types. The A and B types are the same for this MMA instruction.
using TypeA = cutlass::half_t; // MMA A Data Type
auto type_str_a = "half_t";
using TypeB = cutlass::half_t; // MMA B Data Type
auto type_str_b = "half_t";
using TypeC = float; // MMA C Data Type
[[maybe_unused]] auto type_str_c = "float";
using TypeD = float; // MMA D Data Type
auto type_str_d = "float";
using TypeAccumulator = float; // Both TypeC and TypeD are float, use float accumulator type.
// A tensor MxK K-major (Layout T = Row-Major)
Layout layout_A = make_layout(make_shape (Gemm_M, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_M,Gemm_K):(Gemm_K,_1)
// B tensor NxK K-major (Layout N = Column-Major)
Layout layout_B = make_layout(make_shape (Gemm_N, Gemm_K),
make_stride(Gemm_K, Int<1>{})); // (Gemm_N,Gemm_K):(Gemm_K,_1)
// C tensor MxN N-major (Layout T = Row-Major)
Layout layout_C = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// D tensor MxN N-major (Layout T = Row-Major)
Layout layout_D = make_layout(make_shape (Gemm_M, Gemm_N),
make_stride(Gemm_N, Int<1>{})); // (Gemm_M,Gemm_N):(Gemm_N,_1)
// Host allocations and host CuTe tensors for A, B, and C tensors.
thrust::host_vector<TypeA> host_A(Gemm_M * Gemm_K);
Tensor host_tensor_A = make_tensor(host_A.data(), layout_A);
print("host_tensor_A:\t"); print(host_tensor_A); print("\n"); // host_tensor_A: ptr[16b](ADDR_A) o (512,256):(256,_1)
thrust::host_vector<TypeB> host_B(Gemm_N * Gemm_K);
Tensor host_tensor_B = make_tensor(host_B.data(), layout_B);
print("host_tensor_B:\t"); print(host_tensor_B); print("\n"); // host_tensor_B: ptr[16b](ADDR_B) o (1024,256):(256,_1)
thrust::host_vector<TypeC> host_C(Gemm_M * Gemm_N);
Tensor host_tensor_C = make_tensor(host_C.data(), layout_C);
print("host_tensor_C:\t"); print(host_tensor_C); print("\n"); // host_tensor_C: ptr[32b](ADDR_C) o (512,1024):(1024,_1)
// Note that we don't need a host_tensor for D yet.
thrust::device_vector<TypeD> device_D(Gemm_M * Gemm_N);
// Initialize A, B, and C tensors with random values.
initialize_tensor(host_tensor_A);
initialize_tensor(host_tensor_B);
initialize_tensor(host_tensor_C);
// Copy A, B, and C tensors from host memory to device memory
thrust::device_vector<TypeA> device_A = host_A;
thrust::device_vector<TypeB> device_B = host_B;
thrust::device_vector<TypeC> device_C = host_C;
using Alpha = float;
using Beta = float;
Alpha alpha = 1.0f;
Beta beta = 0.0f;
// Set up the input and output tensors and the kernel parameters, then execute the kernel on the device
gemm_host_f16xf16_f32_f32_tnt(device_A.data().get(), layout_A,
device_B.data().get(), layout_B,
device_C.data().get(), layout_C,
device_D.data().get(), layout_D,
alpha, beta);
// Host allocation for D tensor and transfer D tensor from device to host
thrust::host_vector<TypeD> host_D = device_D;
// Create a non-owning CuTe tensor for D tensor
Tensor host_tensor_D = make_tensor(host_D.data(), layout_D);
////////////////////////////////////////////////////////////
//
// Execute reference GEMM kernel
//
////////////////////////////////////////////////////////////
thrust::host_vector<TypeD> host_reference_D(Gemm_M*Gemm_N);
auto host_reference_tensor_D = make_tensor(host_reference_D.data(), layout_D);
reference_gemm<TypeAccumulator>(host_tensor_A, host_tensor_B, host_tensor_C, host_reference_tensor_D, alpha, beta);
////////////////////////////////////////////////////////////
//
// Compare results
//
////////////////////////////////////////////////////////////
auto relative_error = print_matrix_multiply_mollified_relative_error(type_str_a, host_tensor_A,
type_str_b, host_tensor_B,
type_str_d, host_tensor_D, host_reference_tensor_D);
bool success = relative_error <= 0.0;
std::cout << "Execution is " << ((success) ? "successful." : "failed.") << std::endl;
#else
std::cout << "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl;
#endif
return 0;
}