Files
cutlass/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py
Junkai-Wu b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00

2928 lines
110 KiB
Python

# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from typing import Type, Tuple, Union
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.utils.blackwell_helpers as sm100_utils
import math
"""
High-performance persistent blockwise dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture
using CUTE DSL.
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K")
- Matrix B is NxKxL, L is batch dimension, B can be column-major("K")
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
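- Scale factor A (SFA) is applied per row of A and per 128-element block of K (granularity 1x128 in MxK)
- Scale factor B (SFB) is applied per 128x128 block of B (granularity 128x128 in NxK)
- For each K-block iteration, the kernel computes C = A * B and then applies the scale factors: C *= SFA * SFB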
This GEMM kernel supports the following features:
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
- Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations
- Implements TMA multicast with cluster to reduce L2 memory traffic
- Supports persistent tile scheduling to better overlap memory load/store with MMA between tiles
- Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
This GEMM works as follows:
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA (cp.async) copies.
3. MMA warp: Perform matrix multiply-accumulate (MMA) operations using the tcgen05.mma instruction.
4. EPILOGUE warp:
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
- Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final
- Type convert Final matrix to output type.
- Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations.
SM100 tcgen05.mma instructions operate as follows:
- Read matrix A from SMEM
- Read matrix B from SMEM
- Write accumulator to TMEM
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
.. code-block:: bash
python examples/blackwell/blockwise_gemm/blockwise_gemm.py \
--ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \
--scale_dtype Float32 \
--mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \
--mnkl 4096,4096,4096,4
To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/blackwell/blockwise_gemm/blockwise_gemm.py \
--ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \
--scale_dtype Float32 \
--mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \
--mnkl 4096,4096,4096,4
Constraints are the same as in dense_gemm.py:
* Supported input data types: fp8 (e4m3fn);
see the BlockwiseGemmKernel class documentation below for the detailed valid dtype combinations
* A/B tensor must have the same data type
* Mma tiler M must be 64/128/256
* Mma tiler N must be 128, align with the scaleB requirement
* Cluster shape M/N must be positive and power of 2, total cluster size <= 16
* Cluster shape M must be multiple of 2
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned
"""
class BlockwiseGemmKernel:
"""This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, e5m2)
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
:type use_2cta_instrs: bool
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
:type mma_tiler_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
:note: Supported A/B data types:
- Float8E4M3FN
:note: Supported accumulator data types:
- Float32
:note: Supported C data types:
- Float16/BFloat16
- Other data types are not supported for accuracy issues
:note: Constraints:
- MMA tiler M must be 64/128/256
- MMA tiler N must be 128
- Cluster shape M must be multiple of 2
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
Example:
>>> gemm = BlockwiseGemmKernel(
... acc_dtype=cutlass.Float32,
... use_2cta_instrs=True,
... mma_tiler_mn=(128, 128),
... cluster_shape_mn=(2, 2)
... )
>>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream)
"""
def __init__(
    self,
    acc_dtype: Type[cutlass.Numeric],
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
):
    """Initialize the input-independent configuration of the kernel.

    Anything that depends on the actual operands (dtypes, majorness, the K
    extent of the MMA tiler, stage counts, shared memory layouts) is derived
    later by ``_setup_attributes``.

    What is configured here:

    1. MMA instruction settings (tcgen05):
        - acc_dtype: data type of the MMA accumulator.
        - mma_tiler_mn: (M, N) shape of the MMA tiler.
        - use_2cta_instrs: whether the cta_group=2 variant of tcgen05 MMA
          is used.
    2. Cluster shape:
        - cluster_shape_mn: (ClusterM, ClusterN) shape of the CTA cluster.
    3. Warp roles, per-role register budgets and named barriers.

    :param acc_dtype: Data type of the accumulator.
    :type acc_dtype: type[cutlass.Numeric]
    :param use_2cta_instrs: True to use the cta_group=2 MMA variant.
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: (M, N) shape of the MMA tiler.
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: (ClusterM, ClusterN) shape of the cluster.
    :type cluster_shape_mn: Tuple[int, int]
    """
    self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
    self.use_2cta_instrs = use_2cta_instrs
    self.cluster_shape_mn = cluster_shape_mn
    # The trailing 1 is a placeholder K extent; _setup_attributes replaces it
    # once the MMA instruction K is known.
    self.mma_tiler = (*mma_tiler_mn, 1)
    if use_2cta_instrs:
        self.cta_group = tcgen05.CtaGroup.TWO
    else:
        self.cta_group = tcgen05.CtaGroup.ONE
    self.occupancy = 1

    # Warp specialization: warp index (within the CTA) assigned to each role
    self.acc_update_warp_id = (0, 1, 2, 3)
    self.epilog_warp_id = (4, 5, 6, 7)
    self.mma_warp_id = 8
    self.tma_warp_id = 9
    self.scale_warp_id = 10
    self.sched_warp_id = 11

    self.threads_per_warp = 32
    non_sched_warps = (
        *self.acc_update_warp_id,
        *self.epilog_warp_id,
        self.mma_warp_id,
        self.tma_warp_id,
        self.scale_warp_id,
    )
    # Threads of every warp except the tile-scheduler warp
    self.threads_wo_sched = self.threads_per_warp * len(non_sched_warps)
    # Whole CTA: the warps above plus the single scheduler warp
    self.threads_per_cta = self.threads_per_warp * (len(non_sched_warps) + 1)

    # Per-role register budgets
    self.num_regs_uniform_warps = 64
    self.num_regs_sched_warps = 64
    self.num_regs_epilogue_warps = 216
    self.num_regs_acc_update_warps = 216

    # Named barriers: whole-CTA sync, epilogue sync, TMEM pointer sync and
    # scheduler-warp sync
    self.cta_sync_barrier = pipeline.NamedBarrier(
        barrier_id=1,
        num_threads=self.threads_per_cta,
    )
    self.epilog_sync_barrier = pipeline.NamedBarrier(
        barrier_id=2,
        num_threads=self.threads_per_warp * len(self.epilog_warp_id),
    )
    # Warps that need the TMEM allocation address
    tmem_user_warps = (
        self.mma_warp_id,
        *self.epilog_warp_id,
        *self.acc_update_warp_id,
    )
    self.tmem_alloc_barrier = pipeline.NamedBarrier(
        barrier_id=3,
        num_threads=self.threads_per_warp * len(tmem_user_warps),
    )
    self.sched_sync_barrier = pipeline.NamedBarrier(
        barrier_id=4,
        num_threads=self.threads_per_warp,
    )

    self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
    # TMEM column offset at which the final accumulator lives
    self.tmem_final_offset = 384
def _setup_attributes(self):
    """Set up configuration that depends on the GEMM inputs.

    Must be called after ``a_dtype``/``b_dtype``/``c_dtype``/``sfa_dtype``/
    ``sfb_dtype``, ``a_major_mode``/``b_major_mode`` and ``c_layout`` have been
    set on the instance. It:

    - builds the tiled MMA and fixes the K extent of ``self.mma_tiler``
    - derives the per-CTA tile shape and the cluster layout
    - derives and validates the scale-factor tiling
    - computes the number of multicast CTAs for A/B
    - computes the epilogue subtile
    - computes the ACC/AB/C/scale/tile-info stage counts
    - builds the A/B/C/SFA/SFB shared memory layouts
    - sets the number of tensor memory columns to allocate

    :raises ValueError: If the CTA tile is incompatible with the scale-factor
        granularity (1 x 128 x 128 in M x N x K): each CTA tile must span
        exactly one scale block along N and K and be a whole multiple of the
        granularity in every dimension.
    """
    # Configure tiled mma
    tiled_mma = sm100_utils.make_trivial_tiled_mma(
        self.a_dtype,
        self.a_major_mode,
        self.b_major_mode,
        self.acc_dtype,
        self.cta_group,
        self.mma_tiler[:2],
    )

    # Compute mma/cluster/tile shapes.
    # Each mainloop k-tile issues mma_inst_tile_k MMA instructions along K;
    # the resulting K extent has to match one scale block (validated below).
    mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
    mma_inst_tile_k = 4
    self.mma_tiler = (
        self.mma_tiler[0],
        self.mma_tiler[1],
        mma_inst_shape_k * mma_inst_tile_k,
    )
    # The MMA tile's M extent is divided by the number of CTAs in the MMA
    # atom (2 for cta_group=2).
    self.cta_tile_shape_mnk = (
        self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
        self.mma_tiler[1],
        self.mma_tiler[2],
    )

    # Compute cluster layout
    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((*self.cluster_shape_mn, 1)),
        (tiled_mma.thr_id.shape,),
    )

    # Scale-factor granularity and the number of scale values per CTA tile
    self.scale_granularity_m = 1
    self.scale_granularity_n = 128
    self.scale_granularity_k = 128
    self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m
    self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n
    self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k
    if self.scale_k_per_tile != 1:
        raise ValueError("scale_k_per_tile must be 1")
    if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]:
        raise ValueError("scale_m_per_tile must be cta_tile_m")
    if self.scale_n_per_tile != 1:
        raise ValueError("scale_n_per_tile must be 1")
    # The floor divisions above accept tiles that are not a whole number of
    # scale blocks (e.g. a 192-wide N tile still yields scale_n_per_tile == 1).
    # Such a tile straddles two scale blocks while the SFA/SFB smem layouts
    # below stage only scale_*_per_tile values for it, so reject it.
    if self.cta_tile_shape_mnk[0] % self.scale_granularity_m != 0:
        raise ValueError(
            f"cta_tile_m ({self.cta_tile_shape_mnk[0]}) must be a multiple of "
            f"scale_granularity_m ({self.scale_granularity_m})"
        )
    if self.cta_tile_shape_mnk[1] % self.scale_granularity_n != 0:
        raise ValueError(
            f"cta_tile_n ({self.cta_tile_shape_mnk[1]}) must be a multiple of "
            f"scale_granularity_n ({self.scale_granularity_n})"
        )
    if self.cta_tile_shape_mnk[2] % self.scale_granularity_k != 0:
        raise ValueError(
            f"cta_tile_k ({self.cta_tile_shape_mnk[2]}) must be a multiple of "
            f"scale_granularity_k ({self.scale_granularity_k})"
        )

    # Compute number of multicast CTAs for A/B
    self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
    self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1

    # Compute epilogue subtile
    self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
        self.cta_tile_shape_mnk,
        self.use_2cta_instrs,
        self.c_layout,
        self.c_dtype,
    )

    # Setup A/B/C/Scale stage count in shared memory and ACC stage count in
    # tensor memory
    (
        self.num_acc_stage,
        self.num_ab_stage,
        self.num_c_stage,
        self.num_scale_stage,
        self.num_tile_stage,
    ) = self._compute_stages(
        tiled_mma,
        self.mma_tiler,
        self.a_dtype,
        self.b_dtype,
        self.epi_tile,
        self.c_dtype,
        self.c_layout,
        self.sfa_dtype,
        self.sfb_dtype,
        self.scale_m_per_tile * self.scale_k_per_tile,
        self.scale_n_per_tile * self.scale_k_per_tile,
        self.num_smem_capacity,
        self.occupancy,
    )

    # Compute A/B/C/Scale shared memory layout
    self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
        tiled_mma,
        self.mma_tiler,
        self.a_dtype,
        self.num_ab_stage,
    )
    self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
        tiled_mma,
        self.mma_tiler,
        self.b_dtype,
        self.num_ab_stage,
    )
    self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
        self.c_dtype,
        self.c_layout,
        self.epi_tile,
        self.num_c_stage,
    )
    # Scale layouts: the granularity modes have stride 0, so a single stored
    # value is broadcast over its whole granularity block.
    self.sfa_smem_layout_staged = cute.make_layout(
        (
            (self.scale_granularity_m, self.scale_m_per_tile),
            (self.scale_granularity_k, self.scale_k_per_tile),
            self.num_scale_stage,
        ),
        stride=(
            (0, self.scale_k_per_tile),
            (0, 1),
            self.scale_k_per_tile * self.scale_m_per_tile,
        ),
    )
    self.sfb_smem_layout_staged = cute.make_layout(
        (
            (self.scale_granularity_n, self.scale_n_per_tile),
            (self.scale_granularity_k, self.scale_k_per_tile),
            self.num_scale_stage,
        ),
        stride=(
            (0, self.scale_k_per_tile),
            (0, 1),
            self.scale_k_per_tile * self.scale_n_per_tile,
        ),
    )

    # Compute the number of tensor memory allocation columns
    self.num_tmem_alloc_cols = 512
@cute.jit
def __call__(
    self,
    a: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
    sfa: cute.Tensor,
    sfb: cute.Tensor,
    max_active_clusters: cutlass.Constexpr,
    stream: cuda.CUstream,
    epilogue_op: cutlass.Constexpr = lambda x: x,
):
    """Execute the GEMM operation in steps:

    - Set up static attributes before smem/grid/TMA computation
    - Set up TMA load/store atoms and tensors
    - Compute the grid size with regard to hardware constraints
    - Define the shared storage for the kernel
    - Launch the kernel on ``stream``

    :param a: Input tensor A
    :type a: cute.Tensor
    :param b: Input tensor B
    :type b: cute.Tensor
    :param c: Output tensor C
    :type c: cute.Tensor
    :param sfa: Scale factor tensor A
    :type sfa: cute.Tensor
    :param sfb: Scale factor tensor B
    :type sfb: cute.Tensor
    :param max_active_clusters: Maximum number of active clusters
    :type max_active_clusters: cutlass.Constexpr
    :param stream: CUDA stream for asynchronous execution
    :type stream: cuda.CUstream
    :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
    :type epilogue_op: cutlass.Constexpr
    :raises TypeError: If the A and B element types differ.
    :raises ValueError: If the CTA tile shape is incompatible with the
        scale-factor granularity (raised by ``_setup_attributes``).
    """
    # Setup static attributes before smem/grid/tma computation
    self.a_dtype: Type[cutlass.Numeric] = a.element_type
    self.b_dtype: Type[cutlass.Numeric] = b.element_type
    self.c_dtype: Type[cutlass.Numeric] = c.element_type
    self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type
    self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type
    self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
    self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
    self.c_layout = utils.LayoutEnum.from_tensor(c)

    # Check if input data types are compatible with MMA instruction
    if cutlass.const_expr(self.a_dtype != self.b_dtype):
        raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")

    # Setup attributes that depend on gemm inputs
    self._setup_attributes()

    # Same tiled MMA configuration as the one built in _setup_attributes
    tiled_mma = sm100_utils.make_trivial_tiled_mma(
        self.a_dtype,
        self.a_major_mode,
        self.b_major_mode,
        self.acc_dtype,
        self.cta_group,
        self.mma_tiler[:2],
    )
    atom_thr_size = cute.size(tiled_mma.thr_id.shape)

    # Setup TMA load for A
    a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
    a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
    tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
        a_op,
        a,
        a_smem_layout,
        self.mma_tiler,
        tiled_mma,
        self.cluster_layout_vmnk.shape,
        internal_type=(
            cutlass.TFloat32 if a.element_type is cutlass.Float32 else None
        ),
    )

    # Setup TMA load for B
    b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
    b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
    tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
        b_op,
        b,
        b_smem_layout,
        self.mma_tiler,
        tiled_mma,
        self.cluster_layout_vmnk.shape,
        internal_type=(
            cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
        ),
    )

    # Transaction bytes for one AB stage: one A and one B smem stage, scaled
    # by the number of CTAs in the MMA atom.
    a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
    b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
    self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size

    # Setup TMA store for C
    c_cta_v_layout = cute.composition(
        cute.make_identity_layout(c.shape), self.epi_tile
    )
    epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
    tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileS2GOp(),
        c,
        epi_smem_layout,
        c_cta_v_layout,
    )

    # Re-view the scale tensors with stride-0 granularity modes so that they
    # can be tiled with the same (M/N, K) tile shapes as A/B.
    # NOTE(review): assumes sfa is rank-3 (M, K / scale_granularity_k, L) and
    # sfb is rank-3 (N / scale_granularity_n, K / scale_granularity_k, L);
    # confirm against the host-side tensor construction.
    tensor_sfa = cute.make_tensor(
        sfa.iterator,
        cute.make_layout(
            (
                (self.scale_granularity_m, sfa.shape[0]),
                (self.scale_granularity_k, sfa.shape[1]),
                sfa.shape[2],
            ),
            stride=(
                (0, sfa.layout.stride[0]),
                (0, sfa.layout.stride[1]),
                sfa.layout.stride[2],
            ),
        ),
    )
    tensor_sfb = cute.make_tensor(
        sfb.iterator,
        cute.make_layout(
            (
                (self.scale_granularity_n, sfb.shape[0]),
                (self.scale_granularity_k, sfb.shape[1]),
                sfb.shape[2],
            ),
            stride=(
                (0, sfb.layout.stride[0]),
                (0, sfb.layout.stride[1]),
                sfb.layout.stride[2],
            ),
        ),
    )

    # Compute grid size
    self.tile_sched_params, grid = self._compute_grid(
        c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters
    )

    self.buffer_align_bytes = 1024

    c_smem_size = cute.cosize(self.c_smem_layout_staged.outer)

    # Define shared storage for kernel
    @cute.struct
    class SharedStorage:
        # (bidx, bidy, bidz, valid)
        sInfo: cute.struct.Align[
            cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage],
            # 1 byte alignment
            1,
        ]
        ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
        scale_mbar_ptr: cute.struct.MemRange[
            cutlass.Int64, self.num_scale_stage * 2
        ]
        acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
        tile_info_mbar_ptr: cute.struct.MemRange[
            cutlass.Int64, self.num_tile_stage * 2
        ]
        epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2]
        tmem_dealloc_mbar_ptr: cutlass.Int64
        tmem_holding_buf: cutlass.Int32
        # (EPI_TILE_M, EPI_TILE_N, STAGE)
        sC: cute.struct.Align[
            cute.struct.MemRange[
                self.c_dtype,
                c_smem_size,
            ],
            self.buffer_align_bytes,
        ]
        # (MMA, MMA_M, MMA_K, STAGE)
        sA: cute.struct.Align[
            cute.struct.MemRange[
                self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)
            ],
            self.buffer_align_bytes,
        ]
        # (MMA, MMA_N, MMA_K, STAGE)
        sB: cute.struct.Align[
            cute.struct.MemRange[
                self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)
            ],
            self.buffer_align_bytes,
        ]
        # ((granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage)
        sSFA: cute.struct.Align[
            cute.struct.MemRange[
                self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged)
            ],
            self.buffer_align_bytes,
        ]
        # ((granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
        sSFB: cute.struct.Align[
            cute.struct.MemRange[
                self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged)
            ],
            self.buffer_align_bytes,
        ]

    self.shared_storage = SharedStorage

    # Launch the kernel on the given stream
    self.kernel(
        tiled_mma,
        tma_atom_a,
        tma_tensor_a,
        tma_atom_b,
        tma_tensor_b,
        tma_atom_c,
        tma_tensor_c,
        tensor_sfa,
        tensor_sfb,
        self.cluster_layout_vmnk,
        self.a_smem_layout_staged,
        self.b_smem_layout_staged,
        self.c_smem_layout_staged,
        self.sfa_smem_layout_staged,
        self.sfb_smem_layout_staged,
        self.epi_tile,
        self.tile_sched_params,
        epilogue_op,
    ).launch(
        grid=grid,
        block=[self.threads_per_cta, 1, 1],
        cluster=(*self.cluster_shape_mn, 1),
        smem=self.shared_storage.size_in_bytes(),
        stream=stream,
        min_blocks_per_mp=1,
    )
    return
# GPU device kernel
@cute.kernel
def kernel(
self,
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
mSFA_mkl: cute.Tensor,
mSFB_nkl: cute.Tensor,
cluster_layout_vmnk: cute.Layout,
a_smem_layout_staged: cute.ComposedLayout,
b_smem_layout_staged: cute.ComposedLayout,
c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
sfa_smem_layout_staged: cute.Layout,
sfb_smem_layout_staged: cute.Layout,
epi_tile: cute.Tile,
tile_sched_params: utils.PersistentTileSchedulerParams,
epilogue_op: cutlass.Constexpr,
):
"""
GPU device kernel performing the Persistent batched GEMM computation.
"""
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
lane_idx = cute.arch.lane_idx()
#
# Prefetch tma desc
#
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
#
# Setup cta/thread coordinates
#
# Coords inside cluster
bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
is_leader_cta = mma_tile_coord_v == 0
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
cta_rank_in_cluster
)
# Coord inside cta
tidx, _, _ = cute.arch.thread_idx()
#
# Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
#
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
tmem_holding_buf = storage.tmem_holding_buf
# Initialize mainloop ab_pipeline (barrier) and states
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_tma_producer
)
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
consumer_group=ab_pipeline_consumer_group,
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
)
# Initialize mainloop scale_pipeline (barrier) and states
scale_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_per_warp * 1,
)
scale_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_per_warp * len(self.epilog_warp_id),
)
scale_pipeline = pipeline.PipelineCpAsync.create(
barrier_storage=storage.scale_mbar_ptr.data_ptr(),
num_stages=self.num_scale_stage,
producer_group=scale_pipeline_producer_group,
consumer_group=scale_pipeline_consumer_group,
)
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = len(self.epilog_warp_id) * (
2 if use_2cta_instrs else 1
)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_acc_consumer_threads
)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
consumer_group=acc_pipeline_consumer_group,
cta_layout_vmnk=cluster_layout_vmnk,
)
# Initialize epilogue pipeline (barrier) and states
epi_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_per_warp * len(self.acc_update_warp_id),
)
epi_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_per_warp * len(self.epilog_warp_id),
)
epi_pipeline = pipeline.PipelineAsync.create(
barrier_storage=storage.epi_mbar_ptr.data_ptr(),
num_stages=1,
producer_group=epi_pipeline_producer_group,
consumer_group=epi_pipeline_consumer_group,
)
# Initialize tile info pipeline (barrier) and states
tile_info_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_per_warp * 1,
)
tile_info_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_wo_sched,
)
tile_info_pipeline = pipeline.PipelineAsync.create(
barrier_storage=storage.tile_info_mbar_ptr.data_ptr(),
num_stages=self.num_tile_stage,
producer_group=tile_info_pipeline_producer_group,
consumer_group=tile_info_pipeline_consumer_group,
)
# Tensor memory dealloc barrier init
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
barrier_for_retrieve=self.tmem_alloc_barrier,
allocator_warp_id=self.epilog_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
)
# Cluster arrive after barrier init
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_arrive_relaxed()
#
# Setup smem tensor A/B/C/Scale
#
# (EPI_TILE_M, EPI_TILE_N, STAGE)
sC = storage.sC.get_tensor(
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
)
# (MMA, MMA_M, MMA_K, STAGE)
sA = storage.sA.get_tensor(
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
)
# (MMA, MMA_N, MMA_K, STAGE)
sB = storage.sB.get_tensor(
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
)
# (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage)
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
# (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
# (bidx, bidy, bidz, valid)
info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4))
sInfo = storage.sInfo.get_tensor(info_layout)
#
# Compute multicast mask for A/B buffer full
#
a_full_mcast_mask = None
b_full_mcast_mask = None
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
)
#
# Local_tile partition global tensors
#
# (bM, bK, loopM, loopK, loopL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bN, bK, loopN, loopK, loopL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, loopM, loopN, loopL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
# (bM, bK, loopM, loopK, loopL)
gSFA_mkl = cute.local_tile(
mSFA_mkl,
cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)),
(None, None, None),
)
# (bN, bK, loopN, loopK, loopL)
gSFB_nkl = cute.local_tile(
mSFB_nkl,
cute.slice_(self.cta_tile_shape_mnk, (0, None, None)),
(None, None, None),
)
# coordinate
cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl))
cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl))
# (bM, bK, loopM, loopK, loopL)
cSFA = cute.local_tile(
cSFA_mkl,
cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)),
(None, None, None),
)
# (bN, bK, loopN, loopK, loopL)
cSFB = cute.local_tile(
cSFB_nkl,
cute.slice_(self.cta_tile_shape_mnk, (0, None, None)),
(None, None, None),
)
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
#
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
tCgC = thr_mma.partition_C(gC_mnl)
# scale viewed as C tensor
sSFA_view_as_C_layout = cute.make_layout(
(
(self.scale_granularity_m, self.scale_m_per_tile),
self.cta_tile_shape_mnk[1],
self.num_scale_stage,
),
stride=((0, 1), 0, self.scale_m_per_tile),
)
sSFB_view_as_C_layout = cute.make_layout(
(
self.cta_tile_shape_mnk[0],
(self.scale_granularity_n, self.scale_n_per_tile),
self.num_scale_stage,
),
stride=(0, (0, 1), self.scale_n_per_tile),
)
sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout)
sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout)
#
# Partition global/shared tensor for TMA load A/B
#
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
a_cta_layout,
cute.group_modes(sA, 0, 3),
cute.group_modes(tCgA, 0, 3),
)
# TMA load B partition_S/D
b_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
b_cta_layout,
cute.group_modes(sB, 0, 3),
cute.group_modes(tCgB, 0, 3),
)
#
# Partition global/shared tensor for TMA load A/B
#
# load scaleA/scaleB
atom_copy = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(),
mSFA_mkl.element_type,
num_bits_per_copy=mSFA_mkl.element_type.width,
)
tiled_copy_sfa = cute.make_tiled_copy_tv(
atom_copy, cute.make_layout((32,)), cute.make_layout((1,))
)
tiled_copy_sfb = cute.make_tiled_copy_tv(
atom_copy, cute.make_layout((32,)), cute.make_layout((1,))
)
thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx)
thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl)
tAsSFA = thr_copy_sfa.partition_D(sSFA)
tAcSFA = thr_copy_sfa.partition_S(cSFA)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopN, loopK, loopL)
tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl)
tBsSFB = thr_copy_sfb.partition_D(sSFB)
tBcSFB = thr_copy_sfb.partition_S(cSFB)
#
# Partition shared/tensor memory tensor for TiledMMA_A/B/C
#
# (MMA, MMA_M, MMA_K, STAGE)
tCrA = tiled_mma.make_fragment_A(sA)
# (MMA, MMA_N, MMA_K, STAGE)
tCrB = tiled_mma.make_fragment_B(sB)
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, self.num_acc_stage)
)
#
# Cluster wait before tensor memory alloc
#
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_wait()
else:
self.cta_sync_barrier.arrive_and_wait()
#
# Specialized Schedule warp
#
if warp_idx == self.sched_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps)
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
# First tile
work_tile = tile_sched.initial_work_tile_info()
tile_info_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_tile_stage
)
while work_tile.is_valid_tile:
# query next tile
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
# acquire tile info pipeline
tile_info_pipeline.producer_acquire(tile_info_producer_state)
# store the tile info
cur_tile_coord = work_tile.tile_idx
with cute.arch.elect_one():
sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
sInfo[(2, tile_info_producer_state.index)] = cur_tile_coord[2]
sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
work_tile.is_valid_tile
)
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
self.sched_sync_barrier.arrive_and_wait()
# commit tile info pipeline
tile_info_pipeline.producer_commit(tile_info_producer_state)
tile_info_producer_state.advance()
tile_info_pipeline.producer_tail(tile_info_producer_state)
#
# Specialized TMA load warp
#
if warp_idx == self.tma_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
# First tile
work_tile = tile_sched.initial_work_tile_info()
ab_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_ab_stage
)
tile_info_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_tile_stage
)
# get the first tile info
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
# initialize the tile info
tile_info[0] = cur_tile_coord[0]
tile_info[1] = cur_tile_coord[1]
tile_info[2] = cur_tile_coord[2]
tile_info[3] = work_tile.is_valid_tile
is_valid_tile = cutlass.Boolean(1)
is_valid_tile = tile_info[3] == 1
while is_valid_tile:
mma_tile_coord_mnl = (
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
tile_info[1],
tile_info[2],
)
#
# Slice to per mma tile index
#
# ((atom_v, rest_v), loopK)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
# ((atom_v, rest_v), loopK)
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
ab_producer_state.reset_count()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
#
# Tma load loop
#
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
tAgA_k = tAgA_slice[(None, ab_producer_state.count)]
tBgB_k = tBgB_slice[(None, ab_producer_state.count)]
tAsA_pipe = tAsA[(None, ab_producer_state.index)]
tBsB_pipe = tBsB[(None, ab_producer_state.index)]
tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state)
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(
ab_producer_state, peek_ab_empty_status
)
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
ab_producer_state.advance()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
# TMA load A/B
cute.copy(
tma_atom_a,
tAgA_k,
tAsA_pipe,
tma_bar_ptr=tma_bar,
mcast_mask=a_full_mcast_mask,
)
cute.copy(
tma_atom_b,
tBgB_k,
tBsB_pipe,
tma_bar_ptr=tma_bar,
mcast_mask=b_full_mcast_mask,
)
#
# Advance to next tile
#
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
for idx in cutlass.range(4, unroll_full=True):
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
#
# Wait A/B buffer empty
#
ab_pipeline.producer_tail(ab_producer_state)
#
# Specialized Scale load warp
#
if warp_idx == self.scale_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
# First tile
work_tile = tile_sched.initial_work_tile_info()
scale_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_scale_stage
)
tile_info_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_tile_stage
)
# get the first tile info
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
# initialize the tile info
tile_info[0] = cur_tile_coord[0]
tile_info[1] = cur_tile_coord[1]
tile_info[2] = cur_tile_coord[2]
tile_info[3] = work_tile.is_valid_tile
is_valid_tile = cutlass.Boolean(1)
is_valid_tile = tile_info[3] == 1
while is_valid_tile:
#
# Prepare the mask for scaleA/scaleB
#
tApSFA = cute.make_rmem_tensor(
cute.make_layout(
cute.filter_zeros(
cute.slice_(tAsSFA, (None, None, None, 0))
).shape
),
cutlass.Boolean,
)
tBpSFB = cute.make_rmem_tensor(
cute.make_layout(
cute.filter_zeros(
cute.slice_(tBsSFB, (None, None, None, 0))
).shape
),
cutlass.Boolean,
)
# Peek (try_wait) SCALE buffer empty
scale_producer_state.reset_count()
peek_scale_empty_status = cutlass.Boolean(1)
if scale_producer_state.count < k_tile_cnt:
peek_scale_empty_status = scale_pipeline.producer_try_acquire(
scale_producer_state
)
#
# load loop
#
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
#
# Slice to per mma tile index
#
tAsSFA_pipe = cute.filter_zeros(
tAsSFA[(None, None, None, scale_producer_state.index)]
)
tBsSFB_pipe = cute.filter_zeros(
tBsSFB[(None, None, None, scale_producer_state.index)]
)
tAgSFA_k = cute.filter_zeros(
tAgSFA_mkl[
(
None,
None,
None,
tile_info[0],
scale_producer_state.count,
tile_info[2],
)
]
)
tBgSFB_k = cute.filter_zeros(
tBgSFB_nkl[
(
None,
None,
None,
tile_info[1],
scale_producer_state.count,
tile_info[2],
)
]
)
tAcSFA_compact = cute.filter_zeros(
cute.slice_(
tAcSFA,
(
None,
None,
None,
tile_info[0],
scale_producer_state.count,
tile_info[2],
),
)
)
tBcSFB_compact = cute.filter_zeros(
cute.slice_(
tBcSFB,
(
None,
None,
None,
tile_info[1],
scale_producer_state.count,
tile_info[2],
),
)
)
for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])):
tApSFA[((0, 0), i, (0, 0))] = cute.elem_less(
tAcSFA_compact[(i)][0], mSFA_mkl.shape[0]
)
for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])):
tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less(
tBcSFB_compact[(i)][0], mSFB_nkl.shape[0]
)
# Conditionally wait for Scale buffer empty
scale_pipeline.producer_acquire(
scale_producer_state, peek_scale_empty_status
)
# load scaleA/scaleB
cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA)
cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB)
scale_pipeline.producer_commit(scale_producer_state)
# Peek (try_wait) Scale buffer empty
scale_producer_state.advance()
peek_scale_empty_status = cutlass.Boolean(1)
if scale_producer_state.count < k_tile_cnt:
peek_scale_empty_status = scale_pipeline.producer_try_acquire(
scale_producer_state
)
#
# Advance to next tile
#
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
for idx in cutlass.range(4, unroll_full=True):
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
#
# Wait Scale buffer empty
#
scale_pipeline.producer_tail(scale_producer_state)
#
# Specialized MMA warp
#
if warp_idx == self.mma_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
#
# Bar sync for retrieve tensor memory ptr from shared mem
#
tmem.wait_for_alloc()
#
# Retrieving tensor memory ptr and make accumulator tensor
#
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
ab_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
tile_info_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_tile_stage
)
# get the first tile info
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
# initialize the tile info
tile_info[0] = cur_tile_coord[0]
tile_info[1] = cur_tile_coord[1]
tile_info[2] = cur_tile_coord[2]
tile_info[3] = work_tile.is_valid_tile
is_valid_tile = cutlass.Boolean(1)
is_valid_tile = tile_info[3] == 1
while is_valid_tile:
# Peek (try_wait) AB buffer full for k_tile = 0
ab_consumer_state.reset_count()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state
)
# Peek (try_wait) Acc buffer empty for k_tile = 0
acc_producer_state.reset_count()
peek_acc_empty_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_acc_empty_status = acc_pipeline.producer_try_acquire(
acc_producer_state
)
#
# Mma mainloop
#
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Set tensor memory buffer for current tile
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
#
# Wait for accumulator buffer empty
#
if is_leader_cta:
acc_pipeline.producer_acquire(
acc_producer_state, peek_acc_empty_status
)
#
# Reset the ACCUMULATE field for each tile
#
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
if is_leader_cta:
# Conditionally wait for AB buffer full
ab_pipeline.consumer_wait(
ab_consumer_state, peek_ab_full_status
)
# tCtAcc += tCrA * tCrB
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (
None,
None,
kblock_idx,
ab_consumer_state.index,
)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
ab_pipeline.consumer_release(ab_consumer_state)
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
ab_consumer_state.advance()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_tile_cnt:
if is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state
)
#
# Async arrive accumulator buffer full(each kblock)
#
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
# Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1
acc_producer_state.advance()
if acc_producer_state.count < k_tile_cnt:
if is_leader_cta:
peek_acc_empty_status = acc_pipeline.producer_try_acquire(
acc_producer_state
)
#
# Advance to next tile
#
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
for idx in cutlass.range(4, unroll_full=True):
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
#
# Wait for accumulator buffer empty
#
acc_pipeline.producer_tail(acc_producer_state)
#
# Specialized acc update warps
#
if warp_idx <= self.acc_update_warp_id[-1]:
cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps)
#
# Bar sync for retrieve tensor memory ptr from shared memory
#
tmem.wait_for_alloc()
#
# Retrieving tensor memory ptr and make accumulator tensor
#
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc_final = cute.make_tensor(
tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout
)
#
# Partition for epilogue
#
epi_tidx = tidx % 128
(
tiled_copy_t2r,
tiled_copy_r2t,
tTR_tAcc_base,
tTR_rAcc,
tTR_rAcc_final,
tTR_sSFA,
tTR_sSFB,
tRT_rAcc,
tRT_tAcc_base,
) = self.acc_update_tmem_copy_and_partition(
epi_tidx,
tCtAcc_base,
tCtAcc_final,
tCgC,
sSFA_view_as_C,
sSFB_view_as_C,
epi_tile,
)
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
scale_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_scale_stage
)
epi_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, 1
)
tile_info_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_tile_stage
)
# get the first tile info
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
# initialize the tile info
tile_info[0] = cur_tile_coord[0]
tile_info[1] = cur_tile_coord[1]
tile_info[2] = cur_tile_coord[2]
tile_info[3] = work_tile.is_valid_tile
is_valid_tile = cutlass.Boolean(1)
is_valid_tile = tile_info[3] == 1
while is_valid_tile:
# initialize the final accumulator
tTR_rAcc_final.fill(0.0)
tTR_rSFA = cute.make_rmem_tensor(
cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape,
self.acc_dtype,
)
tTR_rSFB = cute.make_rmem_tensor(
cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape,
self.acc_dtype,
)
scale_consumer_state.reset_count()
peek_scale_full_status = cutlass.Boolean(1)
if scale_consumer_state.count < k_tile_cnt:
peek_scale_full_status = scale_pipeline.consumer_try_wait(
scale_consumer_state
)
acc_consumer_state.reset_count()
peek_acc_full_status = cutlass.Boolean(1)
if acc_consumer_state.count < k_tile_cnt:
peek_acc_full_status = acc_pipeline.consumer_try_wait(
acc_consumer_state
)
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
tTR_tAcc = tTR_tAcc_base[
(None, None, None, None, None, acc_consumer_state.index)
]
#
# Wait for scale buffer full
#
scale_pipeline.consumer_wait(
scale_consumer_state, peek_scale_full_status
)
tTR_sSFA_slice = cute.slice_(
tTR_sSFA,
(None, None, None, 0, None, scale_consumer_state.index),
)
tTR_sSFB_slice = cute.slice_(
tTR_sSFB,
(None, None, None, 0, None, scale_consumer_state.index),
)
scale_atom_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.acc_dtype,
num_bits_per_copy=self.acc_dtype.width,
)
cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA)
cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB)
#
# Wait for accumulator buffer full
#
acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status)
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
#
# Update accumulator by scale factor in subtiles
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
for subtile_idx in cutlass.range(subtile_cnt):
#
# Load accumulator from tensor memory buffer to register
#
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
#
# Update accumulator by scale factor
#
tTR_rAcc_subtile = tTR_rAcc_final[
(None, None, None, subtile_idx)
]
tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)]
tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)]
acc_vec = tTR_rAcc.load()
final_vec = tTR_rAcc_subtile.load()
scale_a = tTR_rSFA_subtile.load()
scale_b = tTR_rSFB_subtile.load()
scale = scale_a * scale_b
final_vec = acc_vec * scale + final_vec
tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype))
#
# Async arrive accumulator buffer empty
#
scale_pipeline.consumer_release(scale_consumer_state)
scale_consumer_state.advance()
peek_scale_full_status = cutlass.Boolean(1)
if scale_consumer_state.count < k_tile_cnt:
peek_scale_full_status = scale_pipeline.consumer_try_wait(
scale_consumer_state
)
#
# Async arrive accumulator buffer empty
#
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
peek_acc_full_status = cutlass.Boolean(1)
if acc_consumer_state.count < k_tile_cnt:
peek_acc_full_status = acc_pipeline.consumer_try_wait(
acc_consumer_state
)
tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)]
tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc))
#
# Wait for epilogue buffer empty
#
epi_pipeline.producer_acquire(epi_producer_state)
# copy the accumulator to tensor memory buffer
cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc)
cute.arch.fence_view_async_tmem_store()
#
# Async arrive epilogue buffer full
#
epi_pipeline.producer_commit(epi_producer_state)
epi_producer_state.advance()
#
# Advance to next tile
#
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
for idx in cutlass.range(4, unroll_full=True):
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
#
# Specialized epilogue warps
#
if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]:
cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps)
#
# Alloc tensor memory buffer
#
tmem.allocate(self.num_tmem_alloc_cols)
#
# Bar sync for retrieve tensor memory ptr from shared memory
#
tmem.wait_for_alloc()
#
# Retrieving tensor memory ptr and make accumulator tensor
#
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc_final = cute.make_tensor(
tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout
)
#
# Partition for epilogue
#
epi_tidx = tidx % 128
(
tiled_copy_t2r,
tTR_tAcc_base,
tTR_rAcc,
) = self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs
)
tTR_rC = None
tiled_copy_r2s = None
simt_atom = None
tRS_rC = None
tRS_sC = None
bSG_sC = None
bSG_gC_partitioned = None
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
tiled_copy_t2r, tTR_rC, epi_tidx, sC
)
(
tma_atom_c,
bSG_sC,
bSG_gC_partitioned,
) = self.epilog_gmem_copy_and_partition(
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
)
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
epi_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, 1
)
c_pipeline = None
# Threads/warps participating in tma store pipeline
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilog_warp_id),
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
producer_group=c_producer_group,
)
tile_info_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_tile_stage
)
# get the first tile info
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
# initialize the tile info
tile_info[0] = cur_tile_coord[0]
tile_info[1] = cur_tile_coord[1]
tile_info[2] = cur_tile_coord[2]
tile_info[3] = work_tile.is_valid_tile
is_valid_tile = cutlass.Boolean(1)
is_valid_tile = tile_info[3] == 1
num_prev_subtiles = cutlass.Int32(0)
while is_valid_tile:
mma_tile_coord_mnl = (
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
tile_info[1],
tile_info[2],
)
#
# Slice to per mma tile index
#
bSG_gC = None
# ((ATOM_V, REST_V), EPI_M, EPI_N)
bSG_gC = bSG_gC_partitioned[
(
None,
None,
None,
mma_tile_coord_mnl[0],
mma_tile_coord_mnl[1],
mma_tile_coord_mnl[2],
)
]
# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
tTR_tAcc = tTR_tAcc_base[
(None, None, None, None, None, epi_consumer_state.index)
]
#
# Wait for accumulator buffer full
#
epi_pipeline.consumer_wait(epi_consumer_state)
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
#
# Store accumulator to global memory in subtiles
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
for subtile_idx in cutlass.range(subtile_cnt):
#
# Load accumulator from tensor memory buffer to register
#
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
#
# Convert to C type
#
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
tRS_rC.store(acc_vec)
#
# Store C to shared memory
#
num_prev_subtiles = num_prev_subtiles + 1
c_buffer = num_prev_subtiles % self.num_c_stage
cute.copy(
tiled_copy_r2s,
tRS_rC,
tRS_sC[(None, None, None, c_buffer)],
)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
self.epilog_sync_barrier.arrive_and_wait()
#
# TMA store C to global memory
#
if warp_idx == self.epilog_warp_id[0]:
cute.copy(
tma_atom_c,
bSG_sC[(None, c_buffer)],
bSG_gC[(None, subtile_idx)],
)
# Fence and barrier to make sure shared memory store is visible to TMA store
c_pipeline.producer_commit()
c_pipeline.producer_acquire()
self.epilog_sync_barrier.arrive_and_wait()
#
# Async arrive accumulator buffer empty
#
epi_pipeline.consumer_release(epi_consumer_state)
epi_consumer_state.advance()
#
# Advance to next tile
#
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
for idx in cutlass.range(4, unroll_full=True):
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
#
# Dealloc the tensor memory buffer
#
tmem.relinquish_alloc_permit()
self.epilog_sync_barrier.arrive_and_wait()
tmem.free(tmem_ptr)
#
# Wait for C store complete
#
c_pipeline.producer_tail()
def acc_update_tmem_copy_and_partition(
    self,
    tidx: cutlass.Int32,
    tAcc: cute.Tensor,
    tAcc_final: cute.Tensor,
    gC_mnl: cute.Tensor,
    sSFA: cute.Tensor,
    sSFB: cute.Tensor,
    epi_tile: cute.Tile,
) -> Tuple[
    cute.TiledCopy,
    cute.TiledCopy,
    cute.Tensor,
    cute.Tensor,
    cute.Tensor,
    cute.Tensor,
    cute.Tensor,
    cute.Tensor,
    cute.Tensor,
]:
    """
    Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
    Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination).
    Partition the scale factor tensors with the tensor memory load tiledCopy so they line up with the loaded accumulator.

    The tmem load/store atoms depend on ``self.mma_tiler[0]``:
    64 -> 16x256b with 8 repetitions, 128 -> 32x32b with 32 repetitions,
    anything else -> 16x256b with 1 repetition.

    :param tidx: The thread index in epilogue warp groups
    :type tidx: cutlass.Int32
    :param tAcc: The accumulator tensor to be copied and partitioned, (MMA, MMA_M, MMA_N, STAGE)
    :type tAcc: cute.Tensor
    :param tAcc_final: The final accumulator tensor to be copied and partitioned (same layout as tAcc)
    :type tAcc_final: cute.Tensor
    :param gC_mnl: The global tensor C
    :type gC_mnl: cute.Tensor
    :param sSFA: The scale factor tensor for A
    :type sSFA: cute.Tensor
    :param sSFB: The scale factor tensor for B
    :type sSFB: cute.Tensor
    :param epi_tile: The epilogue tiler
    :type epi_tile: cute.Tile
    :return: A tuple containing (tiled_copy_t2r, tiled_copy_r2t, tTR_tAcc, tTR_rAcc,
        tTR_rAcc_final, tTR_sSFA, tTR_sSFB, tRT_rAcc_final, tRT_tAcc_final) where:
        - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
        - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t)
        - tTR_tAcc: The accumulator tensor partitioned by tiled_copy_t2r, (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
        - tTR_rAcc: The register tensor used to hold one subtile of t2r results, (T2R, T2R_M, T2R_N)
        - tTR_rAcc_final: The register tensor used to hold all t2r subtiles, (T2R, T2R_M, T2R_N, (EPI_M, EPI_N))
        - tTR_sSFA: The tensor SFA partitioned by tiled_copy_t2r
        - tTR_sSFB: The tensor SFB partitioned by tiled_copy_t2r
        - tRT_rAcc_final: The register tensor used to hold all r2t sources, (R2T, R2T_M, R2T_N, (EPI_M, EPI_N))
        - tRT_tAcc_final: The final accumulator tensor partitioned by tiled_copy_r2t, (R2T, R2T_M, R2T_N, EPI_M, EPI_N, STAGE)
    :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]
    """
    # Select the tensor memory load (t2r) and store (r2t) atoms together so the
    # pair always uses the same data-path shape and repetition count.
    tmem_load_atom = None
    tmem_store_atom = None
    if cutlass.const_expr(self.mma_tiler[0] == 64):
        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)),
            self.acc_dtype,
        )
        tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)),
            self.acc_dtype,
        )
    elif cutlass.const_expr(self.mma_tiler[0] == 128):
        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)),
            self.acc_dtype,
        )
        tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)),
            self.acc_dtype,
        )
    else:
        # default: 16dp
        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(1)),
            self.acc_dtype,
        )
        tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(1)),
            self.acc_dtype,
        )
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
    tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
    tAcc_final_epi = cute.flat_divide(
        tAcc_final[((None, None), 0, 0, None)], epi_tile
    )
    # Both tiled copies are shaped by a single (EPI_TILE_M, EPI_TILE_N) subtile
    tiled_copy_t2r = tcgen05.make_tmem_copy(
        tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)]
    )
    tiled_copy_r2t = tcgen05.make_tmem_copy(
        tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)]
    )
    thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
    thr_copy_r2t = tiled_copy_r2t.get_slice(tidx)
    # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
    tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
    gC_mnl_epi = cute.flat_divide(
        gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
    )
    # Divide the scale factor views by the same epilogue tiler so they are
    # partitioned consistently with the accumulator
    sSFA_epi = cute.flat_divide(sSFA, epi_tile)
    sSFB_epi = cute.flat_divide(sSFB, epi_tile)
    # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
    tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
    tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi)
    tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi)
    # (T2R, T2R_M, T2R_N)
    tTR_rAcc = cute.make_rmem_tensor(
        tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype
    )
    # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
    tTR_rAcc_final_ = cute.make_rmem_tensor(
        tTR_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype
    )
    # (T2R, T2R_M, T2R_N, (EPI_M, EPI_N))
    tTR_rAcc_final = cute.group_modes(
        tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_)
    )
    # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL)
    tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi)
    # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, STAGE)
    tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi)
    # (R2T, R2T_M, R2T_N, EPI_M, EPI_N)
    tRT_rAcc_final_ = cute.make_rmem_tensor(
        tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype
    )
    # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N))
    tRT_rAcc_final = cute.group_modes(
        tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_)
    )
    return (
        tiled_copy_t2r,
        tiled_copy_r2t,
        tTR_tAcc,
        tTR_rAcc,
        tTR_rAcc_final,
        tTR_sSFA,
        tTR_sSFB,
        tRT_rAcc_final,
        tRT_tAcc_final,
    )
def epilog_tmem_copy_and_partition(
    self,
    tidx: cutlass.Int32,
    tAcc: cute.Tensor,
    gC_mnl: cute.Tensor,
    epi_tile: cute.Tile,
    use_2cta_instrs: Union[cutlass.Boolean, bool],
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """
    Build the epilogue's tensor memory -> register (t2r) tiled copy and partition
    its source (tensor memory accumulator) and destination (registers) for ``tidx``.

    :param tidx: The thread index in epilogue warp groups
    :type tidx: cutlass.Int32
    :param tAcc: The accumulator tensor to be copied and partitioned
    :type tAcc: cute.Tensor
    :param gC_mnl: The global tensor C, used only to size the register fragment
    :type gC_mnl: cute.Tensor
    :param epi_tile: The epilogue tiler
    :type epi_tile: cute.Tile
    :param use_2cta_instrs: Whether use_2cta_instrs is enabled
    :type use_2cta_instrs: bool
    :return: A tuple (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
        - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
        - tTR_tAcc: The partitioned accumulator tensor, (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
        - tTR_rAcc: The register tensor holding one subtile of t2r results, (T2R, T2R_M, T2R_N)
    :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
    """
    # Tensor memory load atom chosen by the SM100 helper from tile/layout/dtypes
    load_atom = sm100_utils.get_tmem_load_op(
        self.cta_tile_shape_mnk,
        self.c_layout,
        self.c_dtype,
        self.acc_dtype,
        epi_tile,
        use_2cta_instrs,
    )
    # Accumulator split into epilogue subtiles:
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
    acc_subtiles = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
    # The tiled copy is shaped by one (EPI_TILE_M, EPI_TILE_N) subtile
    t2r_copy = tcgen05.make_tmem_copy(load_atom, acc_subtiles[(None, None, 0, 0, 0)])
    my_t2r = t2r_copy.get_slice(tidx)
    # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
    acc_src = my_t2r.partition_S(acc_subtiles)
    # C divided the same way:
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
    c_subtiles = cute.flat_divide(
        gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
    )
    # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
    c_dst = my_t2r.partition_D(c_subtiles)
    # Register fragment with this thread's share of a single subtile: (T2R, T2R_M, T2R_N)
    one_subtile_shape = c_dst[(None, None, None, 0, 0, 0, 0, 0)].shape
    acc_regs = cute.make_rmem_tensor(one_subtile_shape, self.acc_dtype)
    return t2r_copy, acc_src, acc_regs
def epilog_smem_copy_and_partition(
    self,
    tiled_copy_t2r: cute.TiledCopy,
    tTR_rC: cute.Tensor,
    tidx: cutlass.Int32,
    sC: cute.Tensor,
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """
    Build the register -> shared memory (r2s) tiled copy for C and partition its
    source (registers) and destination (shared memory) for thread ``tidx``.

    The r2s copy is derived from ``tiled_copy_t2r`` via ``cute.make_tiled_copy_D``
    so that its partitioning lines up with the t2r copy.

    :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
    :type tiled_copy_t2r: cute.TiledCopy
    :param tTR_rC: The register tensor of C values laid out per the t2r partitioning
    :type tTR_rC: cute.Tensor
    :param tidx: The thread index in epilogue warp groups
    :type tidx: cutlass.Int32
    :param sC: The shared memory tensor to be copied and partitioned
    :type sC: cute.Tensor
    :return: A tuple (tiled_copy_r2s, tRS_rC, tRS_sC) where:
        - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s)
        - tRS_rC: The partitioned tensor C (register source), (R2S, R2S_M, R2S_N)
        - tRS_sC: The partitioned tensor C (smem destination), (R2S, R2S_M, R2S_N, PIPE_D)
    :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
    """
    r2s_copy = cute.make_tiled_copy_D(
        sm100_utils.get_smem_store_op(
            self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
        ),
        tiled_copy_t2r,
    )
    # (R2S, R2S_M, R2S_N, PIPE_D)
    smem_dst = r2s_copy.get_slice(tidx).partition_D(sC)
    # (R2S, R2S_M, R2S_N)
    reg_src = r2s_copy.retile(tTR_rC)
    return r2s_copy, reg_src, smem_dst
def epilog_gmem_copy_and_partition(
    self,
    tidx: cutlass.Int32,
    atom: Union[cute.CopyAtom, cute.TiledCopy],
    gC_mnl: cute.Tensor,
    epi_tile: cute.Tile,
    sC: cute.Tensor,
) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]:
    """Partition C's shared memory buffer (source) and global memory (destination) for the TMA store.

    Only the TMA store path is implemented. ``atom`` is used directly as the TMA copy atom and
    returned unchanged; there is no register-to-global (non-TMA) partitioning here.

    :param tidx: The thread index in epilogue warp groups. Not used by this TMA-only
        implementation; kept for interface compatibility.
    :type tidx: cutlass.Int32
    :param atom: The TMA copy atom for C (tma_atom_c)
    :type atom: cute.CopyAtom or cute.TiledCopy
    :param gC_mnl: The global tensor C
    :type gC_mnl: cute.Tensor
    :param epi_tile: The epilogue tiler
    :type epi_tile: cute.Tile
    :param sC: The shared memory tensor to be copied and partitioned; presumably
        (EPI_TILE_M, EPI_TILE_N, STAGE) — the caller indexes the result by C buffer stage
    :type sC: cute.Tensor
    :return: A tuple (tma_atom_c, bSG_sC, bSG_gC) where:
        - tma_atom_c: The TMA copy atom (``atom``, unchanged)
        - bSG_sC: The partitioned shared memory tensor C, ((ATOM_V, REST_V), STAGE)
        - bSG_gC: The partitioned global tensor C, ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
    :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
    """
    # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
    gC_epi = cute.flat_divide(
        gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
    )
    tma_atom_c = atom
    # Group the two subtile modes into one so TMA partitions a whole subtile
    sC_for_tma_partition = cute.group_modes(sC, 0, 2)
    gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
    # bSG_sC: ((ATOM_V, REST_V), STAGE)
    # bSG_gC: ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
    # CTA coordinate 0 within a size-1 CTA layout
    bSG_sC, bSG_gC = cpasync.tma_partition(
        tma_atom_c,
        0,
        cute.make_layout(1),
        sC_for_tma_partition,
        gC_for_tma_partition,
    )
    return tma_atom_c, bSG_sC, bSG_gC
@staticmethod
def _compute_stages(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: Tuple[int, int, int],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    epi_tile: cute.Tile,
    c_dtype: Type[cutlass.Numeric],
    c_layout: utils.LayoutEnum,
    sfa_dtype: Type[cutlass.Numeric],
    sfb_dtype: Type[cutlass.Numeric],
    sfa_count: int,
    sfb_count: int,
    num_smem_capacity: int,
    occupancy: int,
) -> Tuple[int, int, int, int, int]:
    """Computes the number of pipeline stages for every shared-memory buffer.

    ACC, scale-factor and tile-info stage counts are fixed defaults. The A/B
    stage count is whatever fits in the per-CTA shared memory left after
    reserving the barrier/helper area, the default C stages and the
    scale-factor buffers. Any shared memory still unused after that is turned
    into additional C (epilogue) stages.

    :param tiled_mma: The tiled MMA object defining the core computation.
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
    :type mma_tiler_mnk: tuple[int, int, int]
    :param a_dtype: Data type of operand A.
    :type a_dtype: type[cutlass.Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[cutlass.Numeric]
    :param epi_tile: The epilogue tile shape.
    :type epi_tile: cute.Tile
    :param c_dtype: Data type of operand C (output).
    :type c_dtype: type[cutlass.Numeric]
    :param c_layout: Layout of operand C.
    :type c_layout: utils.LayoutEnum
    :param sfa_dtype: Data type of the A scale factors.
    :type sfa_dtype: type[cutlass.Numeric]
    :param sfb_dtype: Data type of the B scale factors.
    :type sfb_dtype: type[cutlass.Numeric]
    :param sfa_count: Number of A scale-factor elements held per scale stage.
    :type sfa_count: int
    :param sfb_count: Number of B scale-factor elements held per scale stage.
    :type sfb_count: int
    :param num_smem_capacity: Total available shared memory capacity in bytes.
    :type num_smem_capacity: int
    :param occupancy: Target number of CTAs per SM (occupancy).
    :type occupancy: int
    :return: The computed stage counts as
        (ACC stages, A/B operand stages, C stages, scale-factor stages,
        tile-info stages)
    :rtype: tuple[int, int, int, int, int]
    """
    # Default ACC stages: 3 when the per-CTA M extent of the MMA tile is 128,
    # otherwise 6.
    num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6
    # Default C stages
    num_c_stage = 2
    # Default ScaleA/B stages
    num_scale_stage = 10
    # Default Tile info stages
    num_tile_stage = 2

    # Calculate smem layout and size for one stage of A, B, and C
    a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
        tiled_mma,
        mma_tiler_mnk,
        a_dtype,
        1,  # a tmp 1 stage is provided
    )
    b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
        tiled_mma,
        mma_tiler_mnk,
        b_dtype,
        1,  # a tmp 1 stage is provided
    )
    c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
        c_dtype,
        c_layout,
        epi_tile,
        1,
    )

    ab_bytes_per_stage = cute.size_in_bytes(
        a_dtype, a_smem_layout_stage_one
    ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
    # Fixed 1024-byte reservation for barriers/helpers.
    mbar_helpers_bytes = 1024
    c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
    c_bytes = c_bytes_per_stage * num_c_stage
    # Scale-factor buffers for all scale stages, rounded up to a 1024-byte
    # multiple.
    sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage
    sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage
    scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024

    # Calculate A/B stages:
    # Start with total smem per CTA (capacity / occupancy)
    # Subtract reserved bytes and initial C stages bytes
    # Divide remaining by bytes needed per A/B stage
    num_ab_stage = (
        num_smem_capacity // occupancy
        - (mbar_helpers_bytes + c_bytes + scale_bytes)
    ) // ab_bytes_per_stage

    # Refine epilogue stages:
    # Calculate remaining smem after allocating for A/B stages and reserved bytes
    # Add remaining unused smem to epilogue
    num_c_stage += (
        num_smem_capacity
        - occupancy * ab_bytes_per_stage * num_ab_stage
        - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes)
    ) // (occupancy * c_bytes_per_stage)

    return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage
@staticmethod
def _compute_grid(
    c: cute.Tensor,
    cta_tile_shape_mnk: Tuple[int, int, int],
    cluster_shape_mn: Tuple[int, int],
    max_active_clusters: cutlass.Constexpr,
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
    """Size the launch grid for C using the static persistent tile scheduler.

    C is tiled by the (M, N) part of the CTA tile. The resulting per-mode tile
    counts, together with the cluster shape (with L extent 1), form the
    scheduler parameters. The grid is then derived from those parameters and
    ``max_active_clusters``.

    :param c: The output tensor C
    :type c: cute.Tensor
    :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
    :type cta_tile_shape_mnk: tuple[int, int, int]
    :param cluster_shape_mn: Shape of each cluster in M, N dimensions.
    :type cluster_shape_mn: tuple[int, int]
    :param max_active_clusters: Maximum number of active clusters.
    :type max_active_clusters: cutlass.Constexpr
    :return: ``(tile_sched_params, grid)`` - the persistent tile scheduler
        parameters and the grid shape for the kernel launch.
    :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
    """
    # Drop the K extent: only (M, N) is used to tile the output.
    tile_mn = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
    tiled_c = cute.zipped_divide(c, tiler=tile_mn)
    # Fix the in-tile coordinate at 0; the remaining shape is the tile count
    # per (M, N, L) mode.
    tile_counts_mnl = tiled_c[(0, (None, None, None))].shape

    sched_params = utils.PersistentTileSchedulerParams(
        tile_counts_mnl, (*cluster_shape_mn, 1)
    )
    launch_grid = utils.StaticPersistentTileScheduler.get_grid_shape(
        sched_params, max_active_clusters
    )
    return sched_params, launch_grid
@staticmethod
def _get_tma_atom_kind(
    atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean
) -> Union[
    cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp
]:
    """
    Pick the global-to-shared TMA copy op for a given SM count and multicast flag.

    ``atom_sm_cnt`` selects the CTA group (2 -> ``CtaGroup.TWO``,
    1 -> ``CtaGroup.ONE``). ``mcast`` selects the multicast op over the plain op.

    :param atom_sm_cnt: The number of SMs (1 or 2)
    :type atom_sm_cnt: cutlass.Int32
    :param mcast: The multicast flag
    :type mcast: cutlass.Boolean
    :return: The appropriate TMA copy atom kind
    :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
    :raise ValueError: If the atom_sm_cnt is neither 1 nor 2
    """
    if atom_sm_cnt == 2:
        cta_group = tcgen05.CtaGroup.TWO
    elif atom_sm_cnt == 1:
        cta_group = tcgen05.CtaGroup.ONE
    else:
        raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")

    op_kind = (
        cpasync.CopyBulkTensorTileG2SMulticastOp
        if mcast
        else cpasync.CopyBulkTensorTileG2SOp
    )
    return op_kind(cta_group)
@staticmethod
def is_valid_dtypes(
    ab_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
) -> bool:
    """
    Report whether the operand, accumulator and output dtypes are supported.

    Supported combinations: A/B in {Float8E4M3FN, Float8E5M2}, accumulator
    Float32, and C in {Float32, Float16, BFloat16}.

    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param acc_dtype: The data type of the accumulator
    :type acc_dtype: Type[cutlass.Numeric]
    :param c_dtype: The data type of the output tensor
    :type c_dtype: Type[cutlass.Numeric]
    :return: True if every dtype is supported, False otherwise
    :rtype: bool
    """
    supported_ab = {cutlass.Float8E4M3FN, cutlass.Float8E5M2}
    supported_acc = {cutlass.Float32}
    supported_c = {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}

    # All three memberships are evaluated before combining.
    verdicts = (
        ab_dtype in supported_ab,
        acc_dtype in supported_acc,
        c_dtype in supported_c,
    )
    return all(verdicts)
@staticmethod
def is_valid_mma_tiler_and_cluster_shape(
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
) -> bool:
    """
    Report whether an MMA tiler / cluster shape pair is supported.

    Rules enforced:
      * tile M is 64 or 128 without 2-CTA MMA, 128 or 256 with it;
      * tile N is 128;
      * cluster M is a multiple of 2 when 2-CTA MMA is used;
      * both cluster extents are positive powers of two, and the cluster holds
        at most 16 CTAs.

    :param use_2cta_instrs: Whether to use 2 CTA groups
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
    :type cluster_shape_mn: Tuple[int, int]
    :return: True if the mma tiler and cluster shape are valid, False otherwise
    :rtype: bool
    """

    def _is_pow2(value):
        return value > 0 and (value & (value - 1)) == 0

    allowed_tile_m = (128, 256) if use_2cta_instrs else (64, 128)
    ctas_per_mma = 2 if use_2cta_instrs else 1

    # Every check is evaluated (same order as the rule list above) before
    # combining.
    checks = (
        mma_tiler_mn[0] in allowed_tile_m,
        mma_tiler_mn[1] in (128,),
        cluster_shape_mn[0] % ctas_per_mma == 0,
        (
            cluster_shape_mn[0] * cluster_shape_mn[1] <= 16
            and cluster_shape_mn[0] > 0
            and cluster_shape_mn[1] > 0
            and _is_pow2(cluster_shape_mn[0])
            and _is_pow2(cluster_shape_mn[1])
        ),
    )
    return all(checks)
@staticmethod
def is_valid_tensor_alignment(
    m: int,
    n: int,
    k: int,
    l: int,
    ab_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    c_major: str,
) -> bool:
    """
    Check that each operand's contiguous (major) extent spans a whole number
    of 16-byte chunks.

    The operands are viewed as A = (m, k, l), B = (n, k, l) and C = (m, n, l).
    For each operand, the extent of its major mode must be a multiple of
    ``128 // dtype.width`` elements.

    :param m: Rows of A and of C
    :type m: int
    :param n: Rows of B (as (n, k, l)) and columns of C
    :type n: int
    :param k: Reduction extent (columns of A and of B)
    :type k: int
    :param l: Batch extent of A, B and C
    :type l: int
    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param c_dtype: The data type of the output tensor
    :type c_dtype: Type[cutlass.Numeric]
    :param a_major: The major axis of the A tensor ("m" or "k")
    :type a_major: str
    :param b_major: The major axis of the B tensor ("n" or "k")
    :type b_major: str
    :param c_major: The major axis of the C tensor ("m" or "n")
    :type c_major: str
    :return: True if all three operands satisfy the alignment, False otherwise
    :rtype: bool
    """

    def _is_16b_aligned(dtype, leading_mode_is_first, shape):
        leading_extent = shape[0] if leading_mode_is_first else shape[1]
        elems_per_16b = (16 * 8) // dtype.width
        return leading_extent % elems_per_16b == 0

    operand_specs = (
        (ab_dtype, a_major == "m", (m, k, l)),
        (ab_dtype, b_major == "n", (n, k, l)),
        (c_dtype, c_major == "m", (m, n, l)),
    )
    return all(_is_16b_aligned(*spec) for spec in operand_specs)
@staticmethod
def can_implement(
    ab_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    m: int,
    n: int,
    k: int,
    l: int,
    a_major: str,
    b_major: str,
    c_major: str,
) -> bool:
    """
    Report whether this kernel supports the given GEMM configuration.

    The configuration is accepted only if it passes the dtype check, the MMA
    tiler / cluster shape check and the tensor alignment check, and A and B
    are both K-major.

    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param acc_dtype: The data type of the accumulator
    :type acc_dtype: Type[cutlass.Numeric]
    :param c_dtype: The data type of the output tensor
    :type c_dtype: Type[cutlass.Numeric]
    :param use_2cta_instrs: Whether to use 2 CTA groups
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
    :type cluster_shape_mn: Tuple[int, int]
    :param m: Rows of A and of C
    :type m: int
    :param n: Rows of B (as (n, k, l)) and columns of C
    :type n: int
    :param k: Reduction extent
    :type k: int
    :param l: Batch extent
    :type l: int
    :param a_major: The major axis of the A tensor
    :type a_major: str
    :param b_major: The major axis of the B tensor
    :type b_major: str
    :param c_major: The major axis of the C tensor
    :type c_major: str
    :return: True if the gemm can be implemented, False otherwise
    :rtype: bool
    """
    # All checks are evaluated, in this order, before combining.
    checks = (
        # Supported element types
        BlockwiseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype),
        # Supported mma tile shape and cluster shape
        BlockwiseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
        ),
        # Problem shape compatible with load/store alignment
        BlockwiseGemmKernel.is_valid_tensor_alignment(
            m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
        ),
        # Only K-major A and B are supported
        a_major == "k" and b_major == "k",
    )
    return all(checks)
def create_tensors(
    l, m, n, k, a_major, b_major, cd_major, ab_dtype, c_dtype, scale_dtype
):
    """Build host tensors and their CuTe tensor counterparts for one GEMM.

    Seeds torch with 1111, so repeated calls yield identical data. Then
    creates, in this order:

    * A: (m, k) per batch, m-major when ``a_major == "m"``
    * B: (n, k) per batch, n-major when ``b_major == "n"``
    * C: (m, n) per batch, m-major when ``cd_major == "m"``
    * SFA: (m, ceil(k / 128)) per batch, mode-0 major
    * SFB: (ceil(n / 128), ceil(k / 128)) per batch, not mode-0 major

    Each host tensor is wrapped with ``cute_tensor_like`` using a dynamic
    layout and an assumed 16-byte alignment.

    :return: ``(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor,
        a_torch_cpu, b_torch_cpu, c_torch_cpu, sfa_torch_cpu, sfb_torch_cpu,
        c_torch_gpu)``, where ``c_torch_gpu`` is the torch tensor returned
        alongside the C CuTe tensor (used to read back the kernel result)
    """
    torch.manual_seed(1111)

    k_blocks = math.ceil(k / 128)
    n_blocks = math.ceil(n / 128)

    # Host tensors are generated in a fixed order so the seeded RNG always
    # produces the same values for each operand.
    a_host = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype)
    b_host = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype)
    c_host = cutlass_torch.matrix(l, m, n, cd_major == "m", c_dtype)
    sfa_host = cutlass_torch.matrix(l, m, k_blocks, True, scale_dtype)
    sfb_host = cutlass_torch.matrix(l, n_blocks, k_blocks, False, scale_dtype)

    def as_cute(host_tensor, dtype):
        return cutlass_torch.cute_tensor_like(
            host_tensor, dtype, is_dynamic_layout=True, assumed_align=16
        )

    a_dev, _ = as_cute(a_host, ab_dtype)
    b_dev, _ = as_cute(b_host, ab_dtype)
    c_dev, c_torch_gpu = as_cute(c_host, c_dtype)
    sfa_dev, _ = as_cute(sfa_host, scale_dtype)
    sfb_dev, _ = as_cute(sfb_host, scale_dtype)

    device_side = (a_dev, b_dev, c_dev, sfa_dev, sfb_dev)
    host_side = (a_host, b_host, c_host, sfa_host, sfb_host)
    return (*device_side, *host_side, c_torch_gpu)
def run(
    mnkl: Tuple[int, int, int, int],
    ab_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    scale_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    c_major: str,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    use_2cta_instrs: bool,
    tolerance: float,
    warmup_iterations: int = 0,
    iterations: int = 1,
    skip_ref_check: bool = False,
    use_cold_l2: bool = False,
    **kwargs,
):
    """
    Prepare A/B/C tensors, launch GPU kernel, and reference checking.

    Steps: print the configuration; validate it with
    ``BlockwiseGemmKernel.can_implement``; create the tensors; compile and run
    the kernel once; optionally compare against a torch reference
    (``torch.testing.assert_close`` with ``atol=tolerance``, ``rtol=1e-03``);
    then benchmark with ``testing.benchmark``.

    ``**kwargs`` is accepted but not used.

    :raises RuntimeError: If CUDA is not available.
    :raises TypeError: If ``can_implement`` rejects the configuration.
    :return: The execution time reported by ``testing.benchmark``
        (microseconds).
    """
    print("Running Blackwell Persistent Dense Blockwise GEMM test with:")
    print(f"mnkl: {mnkl}")
    print(
        f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}, Scale dtype: {scale_dtype}"
    )
    print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
    print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
    print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
    # This example always uses the TMA store epilogue.
    print(f"Use TMA Store: {'True'}")
    print(f"Tolerance: {tolerance}")
    print(f"Warmup iterations: {warmup_iterations}")
    print(f"Iterations: {iterations}")
    print(f"Skip reference checking: {skip_ref_check}")
    # Unpack parameters
    m, n, k, l = mnkl
    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")
    if not BlockwiseGemmKernel.can_implement(
        ab_dtype,
        acc_dtype,
        c_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
        m,
        n,
        k,
        l,
        a_major,
        b_major,
        c_major,
    ):
        raise TypeError(
            f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
        )
    (
        a_tensor,
        b_tensor,
        c_tensor,
        sfa_tensor,
        sfb_tensor,
        a_torch_cpu,
        b_torch_cpu,
        c_torch_cpu,
        sfa_torch_cpu,
        sfb_torch_cpu,
        c_torch_gpu,
    ) = create_tensors(
        l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype, scale_dtype
    )
    # Configure gemm kernel
    gemm = BlockwiseGemmKernel(
        acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
    )
    # Compute max active clusters on current device
    hardware_info = cutlass.utils.HardwareInfo()
    max_active_clusters = hardware_info.get_max_active_clusters(
        cluster_shape_mn[0] * cluster_shape_mn[1]
    )
    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)
    # Compile gemm kernel. max_active_clusters is supplied at compile time only;
    # the compiled call below does not take it.
    compiled_gemm = cute.compile(
        gemm,
        a_tensor,
        b_tensor,
        c_tensor,
        sfa_tensor,
        sfb_tensor,
        max_active_clusters,
        current_stream,
    )
    # Execution
    compiled_gemm(
        a_tensor,
        b_tensor,
        c_tensor,
        sfa_tensor,
        sfb_tensor,
        current_stream,
    )
    torch.cuda.synchronize()
    # Compute reference result
    if not skip_ref_check:
        # Expand each scale tensor to its operand's full (rows, k, l) shape and
        # apply it element-wise, then compute the reference C = A_scaled @ B_scaled^T
        # per batch.
        def pad_and_multiply(scale, tensor):
            # scale: (rows or ceil(rows / 128), ceil(k / 128), l); tensor: (rows, k, l)
            cm, ck, _ = scale.shape
            m, k, _ = tensor.shape
            IsGroupWise = False
            IsBlockWise = False
            # A scale dimension equal to ceil(dim / 128) means one scale per
            # 128-element block along that dimension.
            if ck == math.ceil(k / 128):
                IsGroupWise = True
            if cm == math.ceil(m / 128):
                IsBlockWise = True
            if not IsBlockWise and not IsGroupWise:
                raise ValueError("Only support granularity = 128")
            k_idx = torch.arange(k, device=scale.device)
            if IsGroupWise:
                k_idx = k_idx // 128
            m_idx = torch.arange(m, device=scale.device)
            if IsBlockWise:
                m_idx = m_idx // 128
            # Gather the scale of each (row, k) element's block: shape (rows, k, l).
            expanded_scale = scale[m_idx[:, None], k_idx, :]
            result = expanded_scale * tensor
            return result

        updated_a = pad_and_multiply(sfa_torch_cpu, a_torch_cpu)
        updated_b = pad_and_multiply(sfb_torch_cpu, b_torch_cpu)
        ref = torch.einsum("mkl,nkl->mnl", updated_a, updated_b).to(
            cutlass_torch.dtype(c_dtype)
        )
        res = c_torch_gpu.view(cutlass_torch.dtype(c_dtype))
        torch.testing.assert_close(res.cpu(), ref.cpu(), atol=tolerance, rtol=1e-03)

    def generate_tensors():
        # Build a fresh set of tensors for one benchmark workspace.
        # create_tensors re-seeds torch with the same seed, so every workspace
        # holds identical values; only the CuTe tensors are passed to the kernel.
        (
            a_tensor,
            b_tensor,
            c_tensor,
            sfa_tensor,
            sfb_tensor,
            a_torch_cpu,
            b_torch_cpu,
            c_torch_cpu,
            sfa_torch_cpu,
            sfb_torch_cpu,
            c_torch_gpu,
        ) = create_tensors(
            l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype, scale_dtype
        )
        return testing.JitArguments(
            a_tensor,
            b_tensor,
            c_tensor,
            sfa_tensor,
            sfb_tensor,
            current_stream,
        )

    workspace_count = 1
    if use_cold_l2:
        # Size one workspace from the host tensors' byte footprint so
        # testing.get_workspace_count can choose how many distinct workspaces to
        # rotate through.
        one_workspace_bytes = (
            a_torch_cpu.numel() * a_torch_cpu.element_size()
            + b_torch_cpu.numel() * b_torch_cpu.element_size()
            + c_torch_cpu.numel() * c_torch_cpu.element_size()
            + sfa_torch_cpu.numel() * sfa_torch_cpu.element_size()
            + sfb_torch_cpu.numel() * sfb_torch_cpu.element_size()
        )
        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )
    exec_time = testing.benchmark(
        compiled_gemm,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        stream=current_stream,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )
    return exec_time  # Return execution time in microseconds
if __name__ == "__main__":

    def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
        """Parse a string such as ``"256,256,512,1"`` into a tuple of ints.

        :raises argparse.ArgumentTypeError: If any item is not an integer.
        """
        try:
            return tuple(int(x.strip()) for x in s.split(","))
        except ValueError:
            # argparse reports only the message; the ValueError context adds nothing.
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            ) from None

    parser = argparse.ArgumentParser(
        description="Example of Dense Persistent GEMM on Blackwell."
    )
    parser.add_argument(
        "--mnkl",
        type=parse_comma_separated_ints,
        default=(256, 256, 512, 1),
        help="mnkl dimensions (comma-separated)",
    )
    parser.add_argument(
        "--mma_tiler_mn",
        type=parse_comma_separated_ints,
        default=(128, 128),
        help="Mma tile shape (comma-separated)",
    )
    parser.add_argument(
        "--cluster_shape_mn",
        type=parse_comma_separated_ints,
        default=(1, 1),
        help="Cluster shape (comma-separated)",
    )
    parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN)
    parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument("--scale_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument(
        "--use_2cta_instrs",
        action="store_true",
        help="Enable 2CTA MMA instructions feature",
    )
    # Only K-major A/B are supported by this kernel (see can_implement).
    parser.add_argument("--a_major", choices=["k"], type=str, default="k")
    parser.add_argument("--b_major", choices=["k"], type=str, default="k")
    parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    parser.add_argument(
        "--warmup_iterations", type=int, default=0, help="Warmup iterations"
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations to run the kernel",
    )
    parser.add_argument(
        "--skip_ref_check", action="store_true", help="Skip reference checking"
    )
    parser.add_argument(
        "--use_cold_l2", action="store_true", default=False, help="Use cold L2"
    )

    args = parser.parse_args()

    if len(args.mnkl) != 4:
        parser.error("--mnkl must contain exactly 4 values")
    # Zero or negative extents would pass the alignment checks (0 % x == 0) but
    # describe no valid problem, so reject them up front.
    if any(dim <= 0 for dim in args.mnkl):
        parser.error("--mnkl values must all be positive")

    if len(args.mma_tiler_mn) != 2:
        parser.error("--mma_tiler_mn must contain exactly 2 values")

    if len(args.cluster_shape_mn) != 2:
        parser.error("--cluster_shape_mn must contain exactly 2 values")

    run(
        args.mnkl,
        args.ab_dtype,
        args.c_dtype,
        args.acc_dtype,
        args.scale_dtype,
        args.a_major,
        args.b_major,
        args.c_major,
        args.mma_tiler_mn,
        args.cluster_shape_mn,
        args.use_2cta_instrs,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )
    print("PASS")