# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import math
import time
from typing import Tuple, Type

import cuda.bindings.driver as cuda
import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack
"""
|
|
A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE DSL.
|
|
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
|
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
|
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
|
|
|
This GEMM kernel supports the following features:
|
|
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
|
|
- Threadblock rasterization to improve data re-use
|
|
- Supports multi-stage pipeline to overlap computation and memory access
|
|
- Implements shared memory buffering for epilogue to increase coalesced global memory access
|
|
|
|
This GEMM works as follows:
|
|
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
|
|
2. Perform matrix multiply-accumulate (MMA) operations.
|
|
3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM).
|
|
|
|
The Ampere tensor core instruction used operates as follows:
|
|
- Read matrix A from SMEM
|
|
- Read matrix B from SMEM
|
|
- Perform MMA operation and store the result in Accumulator(register)
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/ampere/tensorop_gemm.py \
|
|
--mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \
|
|
--ab_dtype Float16 \
|
|
--c_dtype Float16 --acc_dtype Float32 \
|
|
--a_major m --b_major n --c_major n
|
|
|
|
The above example command computes with M=8192, N=8192, K=8192,
|
|
batch_count=1. The atom layout's shape is 2x2x1 and the input, mma
|
|
accumulator, and output data type are set as fp16, fp32 and fp16,
|
|
respectively.
|
|
|
|
To collect performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
|
|
ncu python examples/ampere/tensorop_gemm.py \
|
|
--mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \
|
|
--ab_dtype Float16 \
|
|
--c_dtype Float16 --acc_dtype Float32 \
|
|
--a_major m --b_major n --c_major n \
|
|
--skip_ref_check --iterations 2
|
|
|
|
Constraints:
|
|
* Supported input and output data types: fp16
|
|
* Support accumulator data types: f32
|
|
* Default tile shape is set to be 128x128x32
|
|
* Atom layout's MNK shape is set so that tile shape can be divided by MMA
|
|
instruction shape
|
|
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
|
i.e, number of elements is a multiple of 8
|
|
"""


class TensorOpGemm:
    def __init__(
        self,
        ab_dtype: Type[cutlass.Numeric],
        c_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
        atom_layout_mnk: Tuple[int, int, int],
    ):
        self.ab_dtype = ab_dtype
        self.c_dtype = c_dtype
        self.acc_dtype = acc_dtype
        self.cta_tiler = (128, 128, 32)
        self.num_stages = 3
        self.atom_layout_mnk = atom_layout_mnk
        atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk
        self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32

        self.bM, self.bN, self.bK = self.cta_tiler
        self.mma_inst_shape = (16, 8, 16)
        mmaM, mmaN, mmaK = self.mma_inst_shape

        assert (
            self.bM % (atom_lay_M * mmaM) == 0
        ), "bM must be divisible by MMA instruction"
        assert (
            self.bN % (atom_lay_N * mmaN) == 0
        ), "bN must be divisible by MMA instruction"
        assert atom_lay_K == 1, "this example does not support atom layout K > 1"
        assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction"
        assert self.num_stages >= 3, "num_stages must be greater than or equal to 3"
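        # Worked example of the checks above, assuming the default configuration
        # (128x128x32 tile, 2x2x1 atom layout, 16x8x16 MMA instruction):
        #   bM = 128 is divisible by atom_lay_M * mmaM = 2 * 16 = 32
        #   bN = 128 is divisible by atom_lay_N * mmaN = 2 * 8  = 16
        #   bK = 32  is divisible by mmaK = 16
        #   num_threads = 2 * 2 * 1 * 32 = 128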

    @cute.jit
    def __call__(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mC: cute.Tensor,
        epilogue_op: cutlass.Constexpr = lambda x: x,
    ):
        # The grid divides the problem's M, N, and L dimensions by the
        # respective modes of the tile shape (bM, bN, 1). The K dimension is
        # handled within a block via a multi-stage pipeline.

        self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
        self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
        self.c_major_mode = utils.LayoutEnum.from_tensor(mC)

        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory layout:
        # ///////////////////////////////////////////////////////////////////////////////

        # Create a layout with the size required for the provided tile
        # size and num stages (stages are used for the K dimension) that is also
        # sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that
        # the atom for the shared memory -> register copy does not encounter
        # bank conflicts.

        # Assume the input is 16B aligned
        ab_copy_bits = 128
        sA_layout = self._make_smem_layout_AB(
            mA.element_type,
            self.a_major_mode,
            ab_copy_bits,
            (self.cta_tiler[0], self.cta_tiler[2], self.num_stages),
        )
        sB_layout = self._make_smem_layout_AB(
            mB.element_type,
            self.b_major_mode,
            ab_copy_bits,
            (self.cta_tiler[1], self.cta_tiler[2], self.num_stages),
        )

        # Create a similar layout for C, but without the num_stages (pipeline) mode
        sC_layout = self._make_smem_layout_C(
            mC.element_type,
            self.c_major_mode,
            ab_copy_bits,
            (self.cta_tiler[0], self.cta_tiler[1]),
        )

        # Shared memory allocated for operations on A and B is reused for
        # operations on C. This improves performance by reducing the amount
        # of shared memory requested by each block.
        smem_size = max(
            cute.size_in_bytes(mC.element_type, sC_layout),
            cute.size_in_bytes(mA.element_type, sA_layout)
            + cute.size_in_bytes(mB.element_type, sB_layout),
        )
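        # Rough shared memory budget for the default configuration (fp16 A/B/C,
        # 128x128x32 tile, 3 stages); these numbers are illustrative only:
        #   sA: 128 * 32 * 3 * 2 B = 24 KiB, sB: 128 * 32 * 3 * 2 B = 24 KiB -> 48 KiB
        #   sC: 128 * 128 * 2 B = 32 KiB
        #   smem_size = max(32 KiB, 48 KiB) = 48 KiB per CTA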

        # ///////////////////////////////////////////////////////////////////////////////
        # Tiled copy:
        # The majorness of tA/tB/tC follows the majorness of gA/gB/gC,
        # enabling coalesced accesses to global memory for faster data
        # transfer between global and shared memory.
        # ///////////////////////////////////////////////////////////////////////////////

        # Create a copy atom for a global to shared memory asynchronous copy
        atom_async_copy = cute.make_copy_atom(
            cute.nvgpu.cpasync.CopyG2SOp(
                cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL
            ),
            mA.element_type,
            num_bits_per_copy=ab_copy_bits,
        )

        # Create thread layouts for the tiled copies from the copy atom; the
        # thread layout simply follows the leading dimension of each tensor
        tiled_copy_A = self._make_gmem_tiled_copy_AB(
            atom_async_copy, mA.element_type, self.a_major_mode, ab_copy_bits
        )
        tiled_copy_B = self._make_gmem_tiled_copy_AB(
            atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
        )

        # Create a synchronous copy atom and thread layouts for the epilogue
        c_copy_bits = 128
        atom_sync_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            mC.element_type,
            num_bits_per_copy=c_copy_bits,
        )
        tiled_copy_C = self._make_gmem_tiled_copy_C(
            atom_sync_copy, mC.element_type, self.c_major_mode, c_copy_bits
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Tiled MMA
        # ///////////////////////////////////////////////////////////////////////////////

        # Create an MMA atom with a 16x8x16 MNK shape
        op = cute.nvgpu.warp.MmaF16BF16Op(
            self.ab_dtype, self.acc_dtype, self.mma_inst_shape
        )

        permutation_mnk = (
            self.atom_layout_mnk[0] * self.mma_inst_shape[0],
            # if the atom layout's N-mode is 1, to leverage the largest coalesced
            # shared memory -> register copy, set the tiled mma's N mode to 16
            self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2,
            self.atom_layout_mnk[2] * self.mma_inst_shape[2],
        )
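        # For the default 2x2x1 atom layout and the 16x8x16 instruction, this
        # permutation evaluates to (2*16, 2*8*2, 1*16) = (32, 32, 16); with a
        # 1x1x1 atom layout it would be (16, 16, 16).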

        # Create a tiled mma that tiles the atom according to the specified layout.
        # For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice
        # across M and twice across N
        tC = cute.make_layout(self.atom_layout_mnk)
        tiled_mma = cute.make_tiled_mma(
            op,
            tC,
            permutation_mnk=permutation_mnk,
        )

        # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l)
        grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))

        # Add threadblock rasterization to improve re-use of data
        raster_factor = 1
        grid_dim_n = cute.size(grid_dim[1])
        # Thresholds are picked so that rasterization does not create too many no-op CTAs
        if grid_dim_n > 5:
            raster_factor = 8
        elif grid_dim_n > 2:
            raster_factor = 4
        elif grid_dim_n > 1:
            raster_factor = 2
        rasterization_remap_grid_dim = (
            cute.size(grid_dim[0]) * raster_factor,
            (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor,
            cute.size(grid_dim[2]),
        )
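        # Example of the remapped launch grid, assuming M = N = 8192 with the
        # default 128x128 tile: grid_dim = (64, 64, 1), grid_dim_n = 64 > 5 so
        # raster_factor = 8, and the launched grid becomes
        # (64 * 8, ceil(64 / 8), 1) = (512, 8, 1) -- the same 4096 CTAs, renumbered
        # so that consecutive CTAs land on neighboring tiles (see raster_tile below).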

        self.kernel(
            mA,
            mB,
            mC,
            sA_layout,
            sB_layout,
            sC_layout,
            tiled_copy_A,
            tiled_copy_B,
            tiled_copy_C,
            tiled_mma,
            raster_factor,
            epilogue_op,
        ).launch(
            grid=rasterization_remap_grid_dim,
            block=[self.num_threads, 1, 1],
            smem=smem_size,
        )

    @cute.kernel
    def kernel(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mC: cute.Tensor,
        sA_layout: cute.ComposedLayout,
        sB_layout: cute.ComposedLayout,
        sC_layout: cute.ComposedLayout,
        tiled_copy_A: cute.TiledCopy,
        tiled_copy_B: cute.TiledCopy,
        tiled_copy_C: cute.TiledCopy,
        tiled_mma: cute.TiledMma,
        rasterization_factor: cutlass.Int32,
        epilogue_op: cutlass.Constexpr = lambda x: x,
    ):
        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
        offset_tile_x, offset_tile_y = self.raster_tile(
            bidx, bidy, rasterization_factor
        )
        # Early exit if the CTA is out of range
        if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y:
            pass
        else:
            tiler_coord = (offset_tile_x, offset_tile_y, None)

            # ///////////////////////////////////////////////////////////////////////////////
            # Get the appropriate tiles for this thread block.
            # gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
            # ///////////////////////////////////////////////////////////////////////////////
            gA = cute.local_tile(
                mA[None, None, bidz],
                tiler=self.cta_tiler,
                coord=tiler_coord,
                proj=(1, None, 1),
            )
            gB = cute.local_tile(
                mB[None, None, bidz],
                tiler=self.cta_tiler,
                coord=tiler_coord,
                proj=(None, 1, 1),
            )
            gC = cute.local_tile(
                mC[None, None, bidz],
                tiler=self.cta_tiler,
                coord=tiler_coord,
                proj=(1, 1, None),
            )

            # By default, if the tile k size does not evenly divide the tensor's
            # k extent, then the last tile in the k dimension is irregular.
            # Instead, make the first tile irregular when k is irregular.
            # This allows us to handle the irregular tile first to avoid
            # checking for this condition within the mainloop.

            # residual_k is a non-positive number indicating the amount by which
            # to shift the pointer in the k dimension
            residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
                gA, mode=[2]
            )

            # Move the pointer of gA/gB in the `-k` direction
            gA = cute.domain_offset((0, residual_k, 0), gA)
            gB = cute.domain_offset((0, residual_k, 0), gB)
            # The input is 16B aligned
            gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
            gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
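            # Illustration with the script's default problem size (K = 40, bK = 32):
            # gA has ceil(40 / 32) = 2 k-tiles, so residual_k = 40 - 2 * 32 = -24.
            # Shifting gA/gB by -24 along k makes the *first* k-tile cover original
            # k in [-24, 8) (only k in [0, 8) is valid and the rest is predicated
            # off), while the second k-tile covers the full range [8, 40).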

            # Construct identity (counting) tensors mirroring the global tensors
            # (used for predication only)
            mcA = cute.make_identity_tensor(mA.layout.shape)
            mcB = cute.make_identity_tensor(mB.layout.shape)
            cA = cute.local_tile(
                mcA[None, None, bidz],
                tiler=self.cta_tiler,
                coord=tiler_coord,
                proj=(1, None, 1),
            )
            cB = cute.local_tile(
                mcB[None, None, bidz],
                tiler=self.cta_tiler,
                coord=tiler_coord,
                proj=(None, 1, 1),
            )

            cA = cute.domain_offset((0, residual_k, 0), cA)
            cB = cute.domain_offset((0, residual_k, 0), cB)

            # ///////////////////////////////////////////////////////////////////////////////
            # Create shared memory buffers and get the appropriate fragments for this thread.
            # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
            # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
            # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
            # ///////////////////////////////////////////////////////////////////////////////
            # Shared memory buffer
            smem = cutlass.utils.SmemAllocator()

            sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
            sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
            sC = cute.make_tensor(
                cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
            )

            thr_copy_A = tiled_copy_A.get_slice(tidx)
            thr_copy_B = tiled_copy_B.get_slice(tidx)
            thr_copy_C = tiled_copy_C.get_slice(tidx)
            tAgA = thr_copy_A.partition_S(gA)
            tAsA = thr_copy_A.partition_D(sA)
            tBgB = thr_copy_B.partition_S(gB)
            tBsB = thr_copy_B.partition_D(sB)
            tCsC_epilogue = thr_copy_C.partition_S(sC)
            tCgC_epilogue = thr_copy_C.partition_D(gC)

            # Repeat the partitioning with the identity tensors
            tAcA = thr_copy_A.partition_S(cA)
            tBcB = thr_copy_B.partition_S(cB)

            # ///////////////////////////////////////////////////////////////////////////////
            # Predicate: Mark indices that need to copy when the problem shape isn't a
            # multiple of the tile shape
            # ///////////////////////////////////////////////////////////////////////////////

            # For predication over the tensors A (M/K), B (N/K), and (in the
            # epilogue) C (M/N), we compute it in a fashion similar to an
            # outer product. The predication along one of the dimensions is
            # evaluated and stored in a predication tensor. Then, the
            # predication for the remaining dimension is handled later via an
            # if/else branch at the copy.
            # For A and B, predication booleans along M/N are stored in a
            # predication tensor and K is handled via an if/else branch.

            # Allocate predicate tensors for M and N. Predication is checked
            # at the granularity of a copy atom, so the predicate tensor does not
            # need separate booleans for individual elements within a copy
            # atom (for example, the elements of tAgA.shape[0][0]).
            tApA = cute.make_fragment(
                cute.make_layout(
                    (
                        tAgA.shape[0][1],
                        cute.size(tAgA, mode=[1]),
                        cute.size(tAgA, mode=[2]),
                    ),
                    stride=(cute.size(tAgA, mode=[1]), 1, 0),
                ),
                cutlass.Boolean,
            )
            tBpB = cute.make_fragment(
                cute.make_layout(
                    (
                        tBsB.shape[0][1],
                        cute.size(tBsB, mode=[1]),
                        cute.size(tBsB, mode=[2]),
                    ),
                    stride=(cute.size(tBsB, mode=[1]), 1, 0),
                ),
                cutlass.Boolean,
            )
            # Set predicates for M/N bounds
            for rest_v in range(tApA.shape[0]):
                for m in range(tApA.shape[1]):
                    tApA[rest_v, m, 0] = cute.elem_less(
                        tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
                    )
            for rest_v in range(tBpB.shape[0]):
                for n in range(tBpB.shape[1]):
                    tBpB[rest_v, n, 0] = cute.elem_less(
                        tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
                    )
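            # Example of the M-bound predicate, assuming the script's default M = 112
            # with bM = 128: the last 16 rows of the (single) M-tile fall outside the
            # tensor, so copy atoms whose M coordinate is >= 112 get a False predicate
            # and their cp.async is skipped; the destination smem is zero-filled in the
            # prologue below so stale values never reach the MMA.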

            # ///////////////////////////////////////////////////////////////////////////////
            # Prefetch Prologue
            # ///////////////////////////////////////////////////////////////////////////////
            # Clear the smem tiles to account for predicated-off loads
            tAsA.fill(0)
            tBsB.fill(0)
            cute.arch.sync_threads()
            # Start async loads for the first k-tile. Here we take care of the k residue
            # via an if/else check along the k dimension. Because we shifted the identity
            # tensor by residual_k and because the identity tensor is a counting tensor,
            # the k-coordinate of any identity tensor element that is poison (out of
            # bounds) is negative.
            num_smem_stages = cute.size(tAsA, mode=[3])
            k_tile_count = cute.size(tAgA, mode=[3])
            k_tile_index = cutlass.Int32(0)

            for k in range(tApA.shape[2]):
                if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
                    cute.copy(
                        tiled_copy_A,
                        tAgA[None, None, k, k_tile_index],
                        tAsA[None, None, k, 0],
                        pred=tApA[None, None, k],
                    )
            for k in range(tBpB.shape[2]):
                if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
                    cute.copy(
                        tiled_copy_B,
                        tBgB[None, None, k, k_tile_index],
                        tBsB[None, None, k, 0],
                        pred=tBpB[None, None, k],
                    )
            k_tile_index = k_tile_index + 1
            cute.arch.cp_async_commit_group()

            # Start async loads for the rest of the k-tiles
            for k_tile in range(1, num_smem_stages - 1):
                if k_tile == k_tile_count:
                    tApA.fill(0)
                    tBpB.fill(0)
                cute.copy(
                    tiled_copy_A,
                    tAgA[None, None, None, k_tile_index],
                    tAsA[None, None, None, k_tile],
                    pred=tApA,
                )
                cute.copy(
                    tiled_copy_B,
                    tBgB[None, None, None, k_tile_index],
                    tBsB[None, None, None, k_tile],
                    pred=tBpB,
                )
                k_tile_index = k_tile_index + 1
                cute.arch.cp_async_commit_group()

            # ///////////////////////////////////////////////////////////////////////////////
            # Tile MMA compute thread partitions and allocate accumulators
            # ///////////////////////////////////////////////////////////////////////////////
            thr_mma = tiled_mma.get_slice(tidx)
            tCsA = thr_mma.partition_A(sA)
            tCsB = thr_mma.partition_B(sB)
            tCsC = thr_mma.partition_C(sC)
            tCgC = thr_mma.partition_C(gC)
            tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
            tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
            tCrC = tiled_mma.make_fragment_C(tCgC)
            # Clear the accumulator
            tCrC.fill(0.0)

            # ///////////////////////////////////////////////////////////////////////////////
            # Copy Atom A/B retiling
            # ///////////////////////////////////////////////////////////////////////////////

            # Create the copy atoms for the copy from shared memory to registers
            atom_copy_s2r_A = cute.make_copy_atom(
                cute.nvgpu.warp.LdMatrix8x8x16bOp(
                    self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
                ),
                mA.element_type,
            )
            atom_copy_s2r_B = cute.make_copy_atom(
                cute.nvgpu.warp.LdMatrix8x8x16bOp(
                    self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
                ),
                mB.element_type,
            )

            # Create the tiled copies so that they match the thread-value layout
            # expected by the tiled mma
            tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
            tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)

            thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
            thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
            tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
            tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
            tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
            tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)

            # Current pipe index in smem to read from / write to
            smem_pipe_read = 0
            smem_pipe_write = num_smem_stages - 1

            tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
            tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]

            # ///////////////////////////////////////////////////////////////////////////////
            # PREFETCH register pipeline
            # ///////////////////////////////////////////////////////////////////////////////
            num_k_block = cute.size(tCrA, mode=[2])
            if num_k_block > 1:
                # Wait until our first prefetched tile is loaded in
                cute.arch.cp_async_wait_group(num_smem_stages - 2)
                cute.arch.sync_threads()
                # Prefetch the first k-block rmem from the first k-tile
                cute.copy(
                    tiled_copy_s2r_A,
                    tCsA_p[None, None, 0],
                    tCrA_copy_view[None, None, 0],
                )
                cute.copy(
                    tiled_copy_s2r_B,
                    tCsB_p[None, None, 0],
                    tCrB_copy_view[None, None, 0],
                )

            # ///////////////////////////////////////////////////////////////////////////////
            # Mainloop
            # 1. Shared memory pipeline (gmem -> smem):
            # The default smem pipeline depth is 3, meaning that for shared
            # memory buffers, we allocate three times the size described by the
            # CTA tiler. We prefetch 2 of these buffers before entering the main
            # loop. Considering only the transfer from global memory to shared
            # memory, the general structure of the mainloop is:
            # (1) copy k-tile from gmem to smem;
            # (2) perform gemm computation on k-tile;
            # (3) wait for the next copy to finish.
            # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
            # waits until the number of unfinished copy groups is <= 1. The advantage
            # of this approach is that it allows for simultaneous production
            # (i.e., step (1)) and consumption (i.e., step (2)) of smem.
            # A common misconception is to prefetch N buffers and rewrite
            # the pipeline logic to wait on N-1 pending copies. The disadvantage
            # of that approach is that it requires fully consuming a buffer in
            # order to open an empty buffer for the next copy.
            # 2. Register pipeline (smem -> register):
            # Similarly, the register pipeline produces block i+1 while consuming
            # block i, then produces i+2, and so on. Notably, i and i+1 do not use
            # the same registers, eliminating dependencies on the same register for
            # better parallelism.
            # 3. Combining the smem and register pipelines results in the mainloop.
            # ///////////////////////////////////////////////////////////////////////////////
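            # Sketch of the resulting schedule with the default num_smem_stages = 3
            # (illustrative only; it adds no synchronization of its own):
            #   prologue : issue async copies for k-tiles 0 and 1 (two committed groups)
            #   k_tile i : consume tile i from smem while the copy for tile i + 2 is
            #              issued into a free smem stage; the wait_group(1) at the
            #              last k_block guarantees tile i + 1 has landed before the
            #              next iteration starts reading it.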
            for k_tile in range(k_tile_count):
                for k_block in cutlass.range(num_k_block, unroll_full=True):
                    if k_block == num_k_block - 1:
                        tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
                        tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
                        cute.arch.cp_async_wait_group(num_smem_stages - 2)
                        cute.arch.sync_threads()

                    # Load A, B from shared memory to registers for k_block + 1
                    k_block_next = (k_block + 1) % num_k_block  # static
                    cute.copy(
                        tiled_copy_s2r_A,
                        tCsA_p[None, None, k_block_next],
                        tCrA_copy_view[None, None, k_block_next],
                    )
                    cute.copy(
                        tiled_copy_s2r_B,
                        tCsB_p[None, None, k_block_next],
                        tCrB_copy_view[None, None, k_block_next],
                    )

                    # Fetch next A: To better interleave global memory access and compute
                    # instructions, we intentionally use the sequence: copy A, perform GEMM,
                    # then copy B.
                    if k_block == 0:
                        if k_tile + num_smem_stages - 1 < k_tile_count:
                            cute.copy(
                                tiled_copy_A,
                                tAgA[None, None, None, k_tile_index],
                                tAsA[None, None, None, smem_pipe_write],
                                pred=tApA,
                            )

                    # Thread-level register gemm for k_block
                    cute.gemm(
                        tiled_mma,
                        tCrC,
                        tCrA[None, None, k_block],
                        tCrB[None, None, k_block],
                        tCrC,
                    )

                    # Fetch next B and update smem pipeline read/write
                    if k_block == 0:
                        if k_tile + num_smem_stages - 1 < k_tile_count:
                            cute.copy(
                                tiled_copy_B,
                                tBgB[None, None, None, k_tile_index],
                                tBsB[None, None, None, smem_pipe_write],
                                pred=tBpB,
                            )
                        k_tile_index = k_tile_index + 1
                        cute.arch.cp_async_commit_group()
                        smem_pipe_write = smem_pipe_read
                        smem_pipe_read = smem_pipe_read + 1
                        if smem_pipe_read == num_smem_stages:
                            smem_pipe_read = 0

            # Sync before epilogue
            cute.arch.cp_async_wait_group(0)
            cute.arch.sync_threads()

            # ///////////////////////////////////////////////////////////////////////////////
            # Epilogue with fusion
            # ///////////////////////////////////////////////////////////////////////////////
            tCrD = cute.make_fragment_like(tCrC, self.c_dtype)
            tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype)

            # Copy the results in D back to shared memory
            cute.autovec_copy(tCrD, tCsC)

            # Create a counting tensor for C
            ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
            mcC = cute.make_identity_tensor(
                (
                    cute.size(ceilM) * self.cta_tiler[0],
                    cute.size(ceilN) * self.cta_tiler[1],
                    1,
                )
            )
            cC = cute.local_tile(
                mcC[None, None, bidz],
                tiler=self.cta_tiler,
                coord=tiler_coord,
                proj=(1, 1, None),
            )
            tCcC = thr_copy_C.partition_S(cC)

            tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue)
            # Wait for all writes to shared memory to finish before starting copies
            # using the new layouts
            cute.arch.sync_threads()
            cute.autovec_copy(tCsC_epilogue, tCrC_epilogue)

            # Create a predication tensor for M
            tCpC = cute.make_fragment(
                cute.make_layout(
                    (
                        tCgC_epilogue.shape[0][1],
                        cute.size(tCgC_epilogue, mode=[1]),
                        cute.size(tCgC_epilogue, mode=[2]),
                    ),
                    stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0),
                ),
                cutlass.Boolean,
            )
            for rest_v in range(tCpC.shape[0]):
                for m in range(tCpC.shape[1]):
                    tCpC[rest_v, m, 0] = cute.elem_less(
                        tCcC[(0, rest_v), m, 0][0], mC.shape[0]
                    )

            # Copy to global memory using better vectorization
            for rest_v in range(tCpC.shape[0]):
                for n in range(tCpC.shape[2]):
                    if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]):
                        cute.copy(
                            tiled_copy_C,
                            tCrC_epilogue[None, None, n],
                            tCgC_epilogue[None, None, n],
                            pred=tCpC[None, None, n],
                        )
        return

    def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
        major_mode_size = (
            smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
        )
        major_mode_size = 64 if major_mode_size >= 64 else major_mode_size

        swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
        swizzle_bits = min(swizzle_bits, 3)

        layout_atom_outer = (
            cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
            if major_mode == utils.LayoutEnum.ROW_MAJOR
            else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
        )
        layout_atom = cute.make_composed_layout(
            cute.make_swizzle(swizzle_bits, 3, 3),
            0,
            layout_atom_outer,
        )
        layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2))
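        # Example with fp16 (dtype.width = 16) and 128-bit copies:
        #   K-major A tile (128, 32, stages): major_mode_size = 32,
        #     swizzle_bits = log2(32 * 16 / 128) = 2, layout atom = 8 x 32
        #   M-major A tile: major_mode_size = min(128, 64) = 64,
        #     swizzle_bits = log2(64 * 16 / 128) = 3, layout atom = 64 x 8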
        return layout

    def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler):
        major_mode_size = (
            smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
        )

        swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
        swizzle_bits = min(swizzle_bits, 3)

        layout_atom_outer = (
            cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
            if major_mode == utils.LayoutEnum.ROW_MAJOR
            else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
        )
        layout_atom = cute.make_composed_layout(
            cute.make_swizzle(swizzle_bits, 3, 4),
            0,
            layout_atom_outer,
        )

        # Due to the thread layout of the mma, remove the swizzle in C to
        # prevent shared memory fragments owned by a single thread from
        # holding swizzles
        if major_mode == utils.LayoutEnum.COL_MAJOR:
            layout_atom = cute.make_composed_layout(
                cute.make_swizzle(0, 3, 4), 0, layout_atom_outer
            )
        layout = cute.tile_to_shape(
            layout_atom,
            smem_tiler,
            (0, 1),
        )
        return layout

    def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits):
        copy_elems = copy_bits // dtype.width
        shape_dim_1 = cute.size(self.bK) // copy_elems
        # Thread layout for copy
        thread_layout = cute.make_layout(
            (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
        )
        if major_mode != utils.LayoutEnum.ROW_MAJOR:
            shape_dim_0 = cute.size(self.bM) // copy_elems
            thread_layout = cute.make_layout(
                (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
            )
        # Value layout for copy
        value_layout = (
            cute.make_layout((1, copy_elems))
            if major_mode == utils.LayoutEnum.ROW_MAJOR
            else cute.make_layout((copy_elems, 1))
        )
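        # Example for a K-major (row-major) A with the default configuration
        # (fp16, 128 threads, bK = 32): copy_elems = 128 // 16 = 8,
        # shape_dim_1 = 32 // 8 = 4, so thread_layout = (32, 4) and
        # value_layout = (1, 8): each thread moves 8 contiguous fp16 (16 B) along K,
        # and the 32x4 thread arrangement covers a 32x32 sub-tile per copy iteration.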
        return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)

    def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits):
        copy_elems = copy_bits // dtype.width
        shape_dim_1 = cute.size(self.bN) // copy_elems
        # thread layout for copy
        thread_layout = cute.make_layout(
            (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
        )
        if major_mode != utils.LayoutEnum.ROW_MAJOR:
            shape_dim_0 = cute.size(self.bM) // copy_elems
            thread_layout = cute.make_layout(
                (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
            )
        value_layout = (
            cute.make_layout((1, copy_elems))
            if major_mode == utils.LayoutEnum.ROW_MAJOR
            else cute.make_layout((copy_elems, 1))
        )
        return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)

    def raster_tile(self, i, j, f):
        new_i = i // f
        new_j = (i % f) + (j * f)
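        # Example with f = 4: linear block columns are folded back into tile space,
        # e.g. (i, j) = (5, 2) maps to (new_i, new_j) = (5 // 4, 5 % 4 + 2 * 4) = (1, 9).
        # Consecutive i values therefore work on neighboring column tiles of the same
        # tile row and share the same A row panel, which is the point of rasterization.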
        return (new_i, new_j)


def run(
    a_major: str,
    b_major: str,
    c_major: str,
    ab_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    mnkl: Tuple[int, int, int, int],
    atom_layout_mnk: Tuple[int, int, int],
    warmup_iterations: int = 2,
    iterations: int = 100,
    skip_ref_check: bool = False,
    use_cold_l2: bool = False,
    **kwargs,
):
    print("Running Ampere tensor core GEMM example:")
    print(f"mnkl: {mnkl}")
    print(
        f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
    )
    print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
    print(f"Atoms layout: {atom_layout_mnk}")
    print(f"Warmup iterations: {warmup_iterations}")
    print(f"Iterations: {iterations}")
    print(f"Skip reference checking: {skip_ref_check}")
    print(f"Use cold L2: {use_cold_l2}")
    M, N, K, L = mnkl

    # Create and permute tensor A/B/C
    def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
        # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
        # else: (l, mode0, mode1) -> (mode0, mode1, l)
        shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
        permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
        torch_tensor = (
            torch.empty(*shape, dtype=torch.int32)
            .random_(-2, 2)
            .to(dtype=cutlass_torch.dtype(dtype))
            .permute(permute_order)
            .cuda()
        )
        # Assume the input is 16B aligned
        cute_tensor = (
            from_dlpack(torch_tensor, assumed_align=16)
            .mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0))
            .mark_compact_shape_dynamic(
                mode=(1 if not is_mode0_major else 0),
                stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
                divisibility=(128 // dtype.width),
            )
        )
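        # Example: for A with a_major == "m" (is_mode0_major=True), the torch tensor is
        # allocated as (L, K, M) and permuted to an (M, K, L) view whose M mode is
        # contiguous; with a_major == "k" it is allocated as (L, M, K) and permuted to
        # (M, K, L) with K contiguous. Only the shape/stride metadata changes; no data
        # is moved by permute().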
        return cute_tensor, torch_tensor

    mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
    mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
    mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)

    tensor_op_gemm = TensorOpGemm(
        ab_dtype,
        c_dtype,
        acc_dtype,
        atom_layout_mnk,
    )

    print("Compiling kernel with cute.compile ...")
    compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC)

    print("Executing GEMM kernel...")

    if not skip_ref_check:
        ref = torch.einsum(
            "mkl,nkl->mnl",
            a_torch.to(dtype=torch.float32),
            b_torch.to(dtype=torch.float32),
        ).to(cutlass_torch.dtype(c_dtype))
        compiled_gemm(mA, mB, mC)
        print("Verifying results...")
        torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
        print("Results verified successfully!")

    def generate_tensors():
        a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
        b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
        c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
        return testing.JitArguments(a_workspace, b_workspace, c_workspace)

    workspace_count = 1
    if use_cold_l2:
        one_workspace_bytes = (
            a_torch.numel() * a_torch.element_size()
            + b_torch.numel() * b_torch.element_size()
            + c_torch.numel() * c_torch.element_size()
        )
        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )

    avg_time_us = testing.benchmark(
        compiled_gemm,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
        use_cuda_graphs=False,
    )

    return avg_time_us  # Return execution time in microseconds


if __name__ == "__main__":

    def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
        try:
            return tuple(int(x.strip()) for x in s.split(","))
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    parser = argparse.ArgumentParser(
        description="example of multistage block matmul with CuTe on GPU"
    )
    parser.add_argument(
        "--mnkl", type=parse_comma_separated_ints, default=(112, 136, 40, 1)
    )
    parser.add_argument(
        "--atom_layout_mnk", type=parse_comma_separated_ints, default=(2, 2, 1)
    )
    parser.add_argument(
        "--ab_dtype",
        type=cutlass.dtype,
        choices=[cutlass.Float16],
        default=cutlass.Float16,
    )
    parser.add_argument(
        "--acc_dtype",
        type=cutlass.dtype,
        choices=[cutlass.Float32],
        default=cutlass.Float32,
    )
    parser.add_argument(
        "--c_dtype",
        type=cutlass.dtype,
        choices=[cutlass.Float16],
        default=cutlass.Float16,
    )
    parser.add_argument("--a_major", choices=["k", "m"], default="m")
    parser.add_argument("--b_major", choices=["k", "n"], default="n")
    parser.add_argument("--c_major", choices=["n", "m"], default="n")
    parser.add_argument("--warmup_iterations", default=2, type=int)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument("--skip_ref_check", action="store_true")
    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        default=False,
        help="Use circular buffer tensor sets to ensure L2 cold cache",
    )

    args = parser.parse_args()
    run(
        args.a_major,
        args.b_major,
        args.c_major,
        args.ab_dtype,
        args.c_dtype,
        args.acc_dtype,
        args.mnkl,
        args.atom_layout_mnk,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )
    print("PASS")