Files
cutlass/examples/python/CuTeDSL/blackwell/mixed_input_gemm.py
Junkai-Wu b1d6e2c9b3 v4.3 update. (#2709)
* v4.3 update.

* Update the cute_dsl_api changelog's doc link

* Update version to 4.3.0

* Update the example link

* Update doc to encourage user to install DSL from requirements.txt

---------

Co-authored-by: Larry Wu <larwu@nvidia.com>
2025-10-21 14:26:30 -04:00

3114 lines
123 KiB
Python

# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from enum import Enum, auto
from math import log2, ceil
from typing import Optional, Union
import torch
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.cute.testing as testing
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.runtime import from_dlpack
"""
A mixed-input GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL.
This example demonstrates an implementation of mixed-input GEMM using a TMA plus Blackwell SM100 TensorCore
warp-specialized persistent kernel.
The inputs A and B have different data types. In this example, it's assumed that A is the narrow-precision tensor
and B holds data with a wider precision.
MMA will work in the wide precision of tensor B and tensor A will be transformed to the wide precision of tensor B
following 1 of the 2 possible modes as follows:
1. convert-only mode:
C = type_convert(A) x B
In convert-only mode, tensor A is directly converted to the wide precision of tensor B.
2. convert-scale mode:
C = (type_convert(A) * scale) x B
In convert-scale mode, tensor A is first converted to the wide precision of tensor B and then scaled by the scale tensor.
The scale tensor is in the same precision as tensor B.
The mode is determined by tensor A's data type as follows:
- if tensor A is in int8 or uint8, convert-only mode is used.
- if tensor A is in int4, convert-scale mode is used.
The output tensor C could have the same precision as tensor B or fp32.
To run this example:
.. code-block:: bash
python examples/blackwell/mixed_input_gemm.py \
--a_dtype Int8 --b_dtype BFloat16 \
--scale_granularity_m 0 --scale_granularity_k 0 \
--c_dtype BFloat16 --acc_dtype Float32 \
--mma_tiler_mnk 128,128,64 --cluster_shape_mn 1,1 \
--mnkl 256,512,8192,1
Input A and B have int8 and bf16 data types, respectively. The Blackwell tcgen05 MMA tile shape
is specified as (128,128,64) and the cluster shape is (1,1). The MMA accumulator and output data type
are set as fp32 and bf16, respectively. As tensor A is int8, convert-only mode is used.
scale_granularity_m and scale_granularity_k are set as 0 for convert-only mode.
Here is an example of running convert-scale mode:
.. code-block:: bash
python examples/blackwell/mixed_input_gemm.py \
--a_dtype Int4 --b_dtype BFloat16 \
--scale_granularity_m 1 --scale_granularity_k 256 \
--c_dtype BFloat16 --acc_dtype Float32 \
--mma_tiler_mnk 256,128,128 --cluster_shape_mn 2,1 \
--use_2cta_instrs --use_tma_store \
--mnkl 1024,8192,6144,16
Input A and B have int4 and bf16 data types, respectively. The scale granularity is set as (1,256),
which means each element along the m mode of tensor A has its own scale element and 256 contiguous elements
along the k mode share the same scale element. There is no scale reuse along the L mode. If the GEMM shape is
(M, N, K, L), then the scale tensor shape is (M // scale_granularity_m, K // scale_granularity_k, L),
which is (1024, 6144/256, 16) in this example.
The Blackwell tcgen05 MMA tile shape is specified as (256,128,128) and tcgen05 2CTA feature is enabled.
The cluster shape is (2,1). The MMA accumulator and output data type are set as fp32 and bf16, respectively.
As tensor A is int4, the convert-scale mode is used.
To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/blackwell/mixed_input_gemm.py \
--a_dtype Int8 --b_dtype BFloat16 \
--scale_granularity_m 0 --scale_granularity_k 0 \
--c_dtype BFloat16 --acc_dtype Float32 \
--mma_tiler_mnk 128,128,64 --cluster_shape_mn 1,1 \
--mnkl 256,512,8192,1 \
--warmup_iterations 1 --iterations 10 --skip_ref_check
Besides the requirements from the Blackwell dense GEMM example, there are some constraints for this example:
* The narrow-precision is constrained to be int8, uint8, or int4 and the other data type is bf16 or f16.
* Output data types could only be fp16, bf16, or fp32.
* The scale_granularity_m must be 1 currently.
* The scale_granularity_k must be a multiple of mma_tiler_k and also be divisible by gemm_k.
* The scale tensor must be in m-major mode.
* OOB tiles are not allowed when TMA store is disabled
"""
class TransformMode(Enum):
"""
An enumeration for the possible transform modes of a mixed-input GEMM.
"""
ConvertOnly = auto()
ConvertScale = auto()
class MixedInputGemmKernel:
"""
Mixed-input GEMM kernel for NVIDIA Blackwell SM100 architecture.
This kernel supports GEMM operations where input tensors A and B have different
data types, with tensor A being transformed to the precision of tensor B before
matrix multiplication.
:param scale_granularity_m: Number of elements sharing the same scale factor along the M mode
:type scale_granularity_m: int
:param scale_granularity_k: Number of elements sharing the same scale factor along the K mode
:type scale_granularity_k: int
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
:type use_2cta_instrs: bool
:param mma_tiler_mnk: Shape of the Matrix Multiply-Accumulate (MMA) tile (M, N, K)
:type mma_tiler_mnk: tuple[int, int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: tuple[int, int]
:param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results
:type use_tma_store: bool
"""
def __init__(
self,
scale_granularity_m: int,
scale_granularity_k: int,
acc_dtype: type[cutlass.Numeric],
use_2cta_instrs: bool,
mma_tiler_mnk: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
use_tma_store: bool,
):
"""
Initializes the mixed-input GEMM kernel with a specified configuration.
"""
# Scale granularity defines how many elements share the same scale factor
# along the M and K modes.
self.scale_granularity_m = scale_granularity_m
self.scale_granularity_k = scale_granularity_k
# Set transform mode
if cutlass.const_expr(
self.scale_granularity_m == 0 and self.scale_granularity_k == 0
):
self.scale_mode = TransformMode.ConvertOnly
else:
self.scale_mode = TransformMode.ConvertScale
self.acc_dtype = acc_dtype
self.use_2cta_instrs = use_2cta_instrs
self.cluster_shape_mn = cluster_shape_mn
self.mma_tiler = mma_tiler_mnk
self.use_tma_store = use_tma_store
self.cta_group = (
tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
)
# Set specialized warp ids
self.epilog_warp_id = (
0,
1,
2,
3,
)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.scale_tma_warp_id = 6
# 4 warps to do the transformation
self.transform_warp_id = (
8,
9,
10,
11,
)
self.threads_per_cta = 32 * (
max(
(
self.mma_warp_id,
self.tma_warp_id,
self.scale_tma_warp_id,
*self.epilog_warp_id,
*self.transform_warp_id,
)
)
+ 1
)
# Set barrier id for cta sync, epilogue sync, tmem ptr sync, and transform sync
self.epilog_sync_barrier = pipeline.NamedBarrier(
1, 32 * len(self.epilog_warp_id)
)
self.tmem_ptr_sync_barrier = pipeline.NamedBarrier(2, self.threads_per_cta)
self.transform_sync_barrier = pipeline.NamedBarrier(
3, 32 * len(self.transform_warp_id)
)
self.cta_sync_barrier = pipeline.NamedBarrier(4, self.threads_per_cta)
self.smem_buffer_align_bytes = 1024
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
This method configures various attributes based on the input tensor properties
(data types, leading dimensions) and kernel settings:
- Deduce where the transformed A tensor is stored
- Configuring tiled MMA
- Computing MMA/cluster/tile shapes
- Computing cluster layout
- Computing multicast CTAs for A/B
- Computing epilogue subtile
- Setting up A/scale/B/C stage counts in shared memory
- Setting up transformed A stage count in shared memory or tensor memory
- Computing A/transformed A/scale/B/C memory layout
- Computing tensor memory allocation columns
"""
# Deduce where the transformed A tensor is stored, shared memory(SMEM) or tensor memory(TMEM)
self.transform_a_source = self._get_transform_a_source(self.a_major_mode)
tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.mma_dtype,
self.a_major_mode,
self.b_major_mode,
self.acc_dtype,
self.cta_group,
self.mma_tiler[:2],
self.transform_a_source,
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)),
(tiled_mma.thr_id.shape,),
)
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
self.is_a_mcast = self.num_mcast_ctas_a > 1
self.is_b_mcast = self.num_mcast_ctas_b > 1
if cutlass.const_expr(self.use_tma_store):
self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk,
self.use_2cta_instrs,
self.c_layout,
self.c_dtype,
)
else:
self.epi_tile = self.cta_tile_shape_mnk[:2]
# Compute tensor memory(TMEM) columns and stages for each pipeline
(
self.num_load2trans_stage,
self.num_scale_load2trans_stage,
self.num_trans2mma_stage,
self.num_acc_stage,
self.num_c_stage,
self.num_acc_tmem_cols,
self.num_a_tmem_cols,
) = self._compute_stages_and_tmem_cols(
tiled_mma,
self.mma_tiler,
self.cta_tile_shape_mnk,
self.epi_tile,
self.a_dtype,
self.b_dtype,
self.c_dtype,
self.c_layout,
self.transform_a_source,
self.scale_granularity_m,
self.scale_granularity_k,
self.smem_buffer_align_bytes,
self.use_tma_store,
self.scale_mode,
)
# Ensure load2trans and trans2mma pipelines share same stage count,
# so we can use same pipeline stage index to slice both A and B buffers
if cutlass.const_expr(self.num_load2trans_stage != self.num_trans2mma_stage):
self.num_load2trans_stage = min(
self.num_load2trans_stage, self.num_trans2mma_stage
)
self.num_trans2mma_stage = self.num_load2trans_stage
# Align TMEM columns for allocation
# TMEM allocation requires power-of-2 column alignment
# and must meet minimum allocation requirements
self.num_tmem_alloc_cols = MixedInputGemmKernel.align_up(
self.num_acc_tmem_cols + self.num_a_tmem_cols,
cute.arch.SM100_TMEM_MIN_ALLOC_COLUMNS,
)
self.num_tmem_alloc_cols = 2 ** (ceil(log2(self.num_tmem_alloc_cols)))
# Get smem layout for C tensor when TMA store is enabled
self.c_smem_layout_staged = (
sm100_utils.make_smem_layout_epi(
self.c_dtype,
self.c_layout,
self.epi_tile,
self.num_c_stage,
)
if self.use_tma_store
else None
)
# Get smem layout for A, transformed A, and B
(
self.smem_layout_a,
self.smem_layout_a_transform,
self.smem_layout_b,
) = self._compute_smem_layout(
tiled_mma,
self.mma_tiler,
self.a_dtype,
self.b_dtype,
self.num_load2trans_stage,
self.num_trans2mma_stage,
)
# Get smem layout for scale tensor
self.smem_layout_scale_per_stage = None
self.smem_layout_scale = None
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
# Get smem layout for scale tensor
(
self.smem_layout_scale_per_stage,
self.smem_layout_scale,
) = self.get_smem_layout_scale()
def _validate_inputs(
self,
a: cute.Tensor,
a_scale: Optional[cute.Tensor],
b: cute.Tensor,
c: cute.Tensor,
) -> None:
"""
Validates input tensors and their properties.
:param a: Input tensor A.
:type a: cute.Tensor
:param a_scale: Scale tensor for tensor A (None for ConvertOnly mode).
:type a_scale: Optional[cute.Tensor]
:param b: Input tensor B.
:type b: cute.Tensor
:param c: Output tensor C.
:type c: cute.Tensor
:raises ValueError: If inputs don't meet kernel requirements.
"""
# Validate scale tensor major mode
if cutlass.const_expr(
self.scale_mode == TransformMode.ConvertScale
and utils.LayoutEnum.from_tensor(a_scale).mma_major_mode()
!= tcgen05.OperandMajorMode.MN
):
raise ValueError("scale_major_mode should be m-major")
@cute.jit
def __call__(
self,
a: cute.Tensor,
a_scale: Optional[cute.Tensor], # None for ConvertOnly mode
b: cute.Tensor,
c: cute.Tensor,
max_active_clusters: cutlass.Constexpr,
stream: cuda.CUstream,
):
"""
Executes the Mixed Input GEMM operation.
This method sets up the kernel parameters, computes the grid size,
defines the shared storage, and launches the kernel.
The execution steps are as follows:
- Setup static attributes before smem/grid/tma computation.
- Setup TMA load/store atoms and tensors.
- Compute grid size with regard to hardware constraints.
- Define shared storage for kernel.
- Launch the kernel synchronously.
:param a: Input tensor A.
:type a: cute.Tensor
:param a_scale: Scale tensor for tensor A (None for ConvertOnly mode).
:type a_scale: Optional[cute.Tensor]
:param b: Input tensor B.
:type b: cute.Tensor
:param c: Output tensor C.
:type c: cute.Tensor
:param max_active_clusters: Maximum number of active clusters to launch.
:type max_active_clusters: cutlass.Constexpr
:param stream: CUDA stream to launch the kernel on.
:type stream: cuda.CUstream
"""
self.a_dtype: type[cutlass.Numeric] = a.element_type
self.a_scale_dtype: type[cutlass.Numeric] = (
a_scale.element_type
if self.scale_mode is TransformMode.ConvertScale
else None
)
self.b_dtype: type[cutlass.Numeric] = b.element_type
self.c_dtype: type[cutlass.Numeric] = c.element_type
self.mma_dtype = self.b_dtype
self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
self.scale_major_mode = (
utils.LayoutEnum.from_tensor(a_scale).mma_major_mode()
if self.scale_mode is TransformMode.ConvertScale
else None
)
self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = utils.LayoutEnum.from_tensor(c)
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
# Get gmem layout for scale tensor
self.gmem_layout_scale = self.get_gmem_layout_scale(a.shape)
# Validate inputs
self._validate_inputs(a, a_scale, b, c)
# Setup attributes that dependent on gemm inputs
self._setup_attributes()
tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.mma_dtype,
self.a_major_mode,
self.b_major_mode,
self.acc_dtype,
self.cta_group,
self.mma_tiler[:2],
self.transform_a_source,
)
# Set up gmem copy atoms for A, scale, and B
a_op = self._get_tma_atom_kind(self.is_a_mcast, self.use_2cta_instrs, False)
b_op = self._get_tma_atom_kind(self.is_b_mcast, self.use_2cta_instrs, True)
a_scale_op = a_op
# Deduce TMA copy atom and TMA tensor for A, scale, and B
smem_layout_a_per_stage = cute.slice_(self.smem_layout_a, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op,
a,
smem_layout_a_per_stage,
self.mma_tiler,
tiled_mma,
self.cluster_layout_vmnk.shape,
internal_type=(
cutlass.TFloat32 if a.element_type is cutlass.Float32 else None
),
)
tma_atom_scale, tma_tensor_scale = None, None
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
# Partition smem layout for scale tensor to make it compatible with TMA atom
smem_layout_for_tma_atom = cute.get(
tiled_mma._thrfrg_A(self.smem_layout_scale_per_stage.outer), mode=[1]
)
# ((MMA_M, MMA_K), REST_M, REST_K)
smem_layout_for_tma_atom = cute.dice(
smem_layout_for_tma_atom,
(1, (1,) * cute.rank(self.smem_layout_scale_per_stage.outer)),
)
tma_atom_scale, tma_tensor_scale = cute.nvgpu.make_tiled_tma_atom_A(
a_scale_op,
cute.make_tensor(a_scale.iterator, self.gmem_layout_scale),
smem_layout_for_tma_atom,
# (SCALE_M, 1, SCALE_K)
(self.scale_tile_shape[0], 1, self.scale_tile_shape[1]),
tiled_mma,
self.cluster_layout_vmnk.shape,
internal_type=(
cutlass.TFloat32
if a_scale.element_type is cutlass.Float32
else None
),
)
smem_layout_b_per_stage = cute.slice_(self.smem_layout_b, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op,
b,
smem_layout_b_per_stage,
self.mma_tiler,
tiled_mma,
self.cluster_layout_vmnk.shape,
internal_type=(
cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
),
)
# Calculate copy size for tensor A, B, and scale
a_copy_size = cute.size_in_bytes(self.a_dtype, smem_layout_a_per_stage)
b_copy_size = cute.size_in_bytes(self.b_dtype, smem_layout_b_per_stage)
a_scale_copy_size = (
cute.size_in_bytes(self.a_scale_dtype, self.smem_layout_scale_per_stage)
if self.scale_mode is TransformMode.ConvertScale
else 0
)
self.num_tma_load_bytes_a = a_copy_size
self.num_tma_load_bytes_b = b_copy_size * cute.size(tiled_mma.thr_id.shape)
self.num_tma_load_bytes_scale = a_scale_copy_size
self.tile_sched_params, grid = self._compute_grid(
c,
self.cta_tile_shape_mnk,
self.cluster_shape_mn,
max_active_clusters,
)
tma_atom_c = None
tma_tensor_c = None
c_smem_size = 0
if cutlass.const_expr(self.use_tma_store):
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
c,
epi_smem_layout,
self.epi_tile,
)
c_smem_size = cute.cosize(self.c_smem_layout_staged.outer)
# Shared memory structure
a_smem_size = cute.cosize(self.smem_layout_a.outer)
b_smem_size = cute.cosize(self.smem_layout_b.outer)
a_transform_smem_size = (
cute.cosize(self.smem_layout_a_transform.outer)
if self.transform_a_source == tcgen05.OperandSource.SMEM
else 0
)
a_scale_smem_size = (
cute.cosize(self.smem_layout_scale.outer)
if self.scale_mode is TransformMode.ConvertScale
else 0
)
@cute.struct
class SharedStorage:
a_load2trans_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_load2trans_stage
]
a_load2trans_empty_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_load2trans_stage
]
a_scale_load2trans_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_scale_load2trans_stage
]
a_scale_load2trans_empty_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_scale_load2trans_stage
]
a_trans2mma_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_trans2mma_stage
]
a_trans2mma_empty_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_trans2mma_stage
]
b_load2mma_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_load2trans_stage
]
b_load2mma_empty_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_load2trans_stage
]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_holding_buf: cutlass.Int32
# Tensor buffers
# (EPI_TILE_M, EPI_TILE_N, STAGE)
smem_C: cute.struct.Align[
cute.struct.MemRange[self.c_dtype, c_smem_size],
self.smem_buffer_align_bytes,
]
# (MMA, MMA_M, MMA_K, STAGE)
smem_A: cute.struct.Align[
cute.struct.MemRange[self.a_dtype, a_smem_size],
self.smem_buffer_align_bytes,
]
# (MMA, MMA_N, MMA_K, STAGE)
smem_B: cute.struct.Align[
cute.struct.MemRange[self.b_dtype, b_smem_size],
self.smem_buffer_align_bytes,
]
# (MMA, MMA_M, MMA_K, STAGE)
smem_A_transform: cute.struct.Align[
cute.struct.MemRange[self.mma_dtype, a_transform_smem_size],
self.smem_buffer_align_bytes,
]
# (MMA, MMA_M_SCALE, MMA_K_SCALE, STAGE)
smem_A_scale: cute.struct.Align[
cute.struct.MemRange[self.mma_dtype, a_scale_smem_size],
self.smem_buffer_align_bytes,
]
self.shared_storage = SharedStorage
# Launch kernel
self.kernel(
tiled_mma,
tma_atom_a,
tma_tensor_a,
tma_atom_scale,
tma_tensor_scale,
tma_atom_b,
tma_tensor_b,
tma_atom_c,
tma_tensor_c if self.use_tma_store else c,
self.cluster_layout_vmnk,
self.smem_layout_a,
self.smem_layout_scale,
self.smem_layout_a_transform,
self.smem_layout_b,
self.c_smem_layout_staged,
self.epi_tile,
self.tile_sched_params,
).launch(
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
smem=self.shared_storage.size_in_bytes(),
stream=stream,
min_blocks_per_mp=1,
)
return
# GPU device kernel
@cute.kernel
def kernel(
self,
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_s: Optional[cute.CopyAtom],
mS_mkl: Optional[cute.Tensor],
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: Optional[cute.CopyAtom],
mC_mnl: cute.Tensor,
cluster_layout_vmnk: cute.Layout,
a_smem_layout: cute.ComposedLayout,
scale_smem_layout: cute.ComposedLayout,
a_smem_layout_transform: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
c_smem_layout_staged: cute.ComposedLayout,
epi_tile: cute.Tile,
tile_sched_params: utils.PersistentTileSchedulerParams,
):
"""
GPU device kernel performing the Persistent Mixed-Input GEMM computation.
"""
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
# Prefetch TMA descriptors
if warp_idx == self.epilog_warp_id[0]:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
cpasync.prefetch_descriptor(tma_atom_s)
if cutlass.const_expr(self.use_tma_store):
cpasync.prefetch_descriptor(tma_atom_c)
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
bidx, bidy, bidz = cute.arch.block_idx()
# Compute how many k_tiles share the same scale
num_k_tiles_per_scale = self.scale_granularity_k // self.cta_tile_shape_mnk[2]
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
is_leader_cta = mma_tile_coord_v == 0
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
cta_rank_in_cluster
)
tidx, _, _ = cute.arch.thread_idx()
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
# Initialize load2transform pipeline, which tracks the dependencies between TMA's loading
# of A and B, and the transformation of A and MMA's consumption
transform_thread_idx = (
tidx - 32 * self.transform_warp_id[0]
if tidx >= 32 * self.transform_warp_id[0]
else tidx
)
a_load2trans_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=storage.a_load2trans_full_mbar_ptr.data_ptr(),
num_stages=self.num_load2trans_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.num_mcast_ctas_a * len(self.transform_warp_id),
),
tx_count=self.num_tma_load_bytes_a,
cta_layout_vmnk=cluster_layout_vmnk,
tidx=transform_thread_idx,
mcast_mode_mn=(1, 0), # multicast for A will only happen on the M-mode
)
# Initialize scale_load2trans pipeline, which tracks the dependencies between TMA's loading
# of scale, and the transformation of A
scale_load2trans_pipeline = None
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
num_producers_a_scale = self.num_mcast_ctas_a
scale_load2trans_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=storage.a_scale_load2trans_full_mbar_ptr.data_ptr(),
num_stages=self.num_scale_load2trans_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
num_producers_a_scale
* len(self.transform_warp_id)
* num_k_tiles_per_scale,
),
tx_count=self.num_tma_load_bytes_scale,
cta_layout_vmnk=cluster_layout_vmnk,
tidx=transform_thread_idx,
mcast_mode_mn=(
1,
0,
), # multicast for scale_a will only happen on the M-mode
)
# Initialize transform2mma pipeline, which tracks the dependencies between the transformation
# of A and MMA's consumption of transformed A
cta_v_size = cute.size(cluster_layout_vmnk, mode=[0])
trans2mma_pipeline = pipeline.PipelineAsyncUmma.create(
barrier_storage=storage.a_trans2mma_full_mbar_ptr.data_ptr(),
num_stages=self.num_trans2mma_stage,
producer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.transform_warp_id) * cta_v_size,
),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
cta_layout_vmnk=cluster_layout_vmnk,
)
# Initialize pipeline for tensor B load to MMA
# MMA warp informs TMA warp to proceed to load next tile of B tensor
b_load2mma_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.b_load2mma_full_mbar_ptr.data_ptr(),
num_stages=self.num_load2trans_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, self.num_mcast_ctas_b
),
tx_count=self.num_tma_load_bytes_b,
cta_layout_vmnk=cluster_layout_vmnk,
mcast_mode_mn=(0, 1), # multicast for B will only happen on the N-mode
)
# Initialize accumulator pipeline, which tracks the dependencies between
# MMA's computation of accumulators and epilogue warps' consumption of accumulators
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, cta_v_size * len(self.epilog_warp_id)
),
cta_layout_vmnk=cluster_layout_vmnk,
)
# Tensor memory dealloc barrier init
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
barrier_for_retrieve=self.tmem_ptr_sync_barrier,
allocator_warp_id=self.epilog_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
)
# Cluster arrive after barrier init
if cutlass.const_expr(cute.size(self.cluster_shape_mn) > 1):
cute.arch.cluster_arrive_relaxed()
# Setup smem tensor A/scale/B/C
sC = (
storage.smem_C.get_tensor(
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
)
if self.use_tma_store
else None
)
sA_input = storage.smem_A.get_tensor(
a_smem_layout.outer, swizzle=a_smem_layout.inner
)
sS_input = (
storage.smem_A_scale.get_tensor(
scale_smem_layout.outer, swizzle=scale_smem_layout.inner
)
if self.scale_mode is TransformMode.ConvertScale
else None
)
sB_input = storage.smem_B.get_tensor(
b_smem_layout.outer, swizzle=b_smem_layout.inner
)
sA_transform = None
# Get smem tensor for transformed A when transform_a_source is SMEM
if cutlass.const_expr(self.transform_a_source == tcgen05.OperandSource.SMEM):
sA_transform = storage.smem_A_transform.get_tensor(
a_smem_layout_transform.outer, swizzle=a_smem_layout_transform.inner
)
# Compute multicast mask for A/B buffer full
a_full_mcast_mask = None
b_full_mcast_mask = None
s_full_mcast_mask = None
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
# scale tensor share the same multicast mask with A tensor
s_full_mcast_mask = a_full_mcast_mask
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
)
# local_tile partition global tensors
# (bM, bK, loopM, loopK, loopL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bM, bK, loopM, loopK, loopL)
gS_mkl = (
cute.local_tile(
mS_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
if self.scale_mode is TransformMode.ConvertScale
else None
)
# (bN, bK, loopN, loopK, loopL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, loopM, loopN, loopL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_tile_cnt = cute.size(gA_mkl, mode=[3])
# Partition global tensor for TiledMMA_A/B/C
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
tCgS = (
thr_mma.partition_A(gS_mkl)
if self.scale_mode is TransformMode.ConvertScale
else None
)
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
tCgC = thr_mma.partition_C(gC_mnl)
# Setup copy atom to load A from shared memory for further transformation
copy_atom_a_input = (
cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), self.a_dtype, num_bits_per_copy=32
)
if self.scale_mode is TransformMode.ConvertScale
else None
)
a_smem_shape = tiled_mma.partition_shape_A(
cute.dice(self.mma_tiler, (1, None, 1))
)
# Setup copy atom to store transformed A into tensor memory or shared memory
copy_atom_a_transform = self._get_copy_atom_a_transform(
self.mma_dtype,
self.use_2cta_instrs,
self.transform_a_source,
a_smem_shape,
self.a_dtype,
)
# Partition global/shared tensor for TMA load A/B
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
a_cta_layout,
cute.group_modes(sA_input, 0, 3),
cute.group_modes(tCgA, 0, 3),
)
tCsS = None
tSsS = None
tSgS = None
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
# (MMA, MMA_M, MMA_K, STAGE)
tCsS = thr_mma.partition_A(sS_input)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
tSsS, tSgS = self.scale_tma_partition(
tCsS,
tCgS,
tma_atom_s,
block_in_cluster_coord_vmnk,
a_cta_layout,
)
# TMA load B partition_S/D
b_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
b_cta_layout,
cute.group_modes(sB_input, 0, 3),
cute.group_modes(tCgB, 0, 3),
)
# (MMA, MMA_N, MMA_K, STAGE)
tCrB = tiled_mma.make_fragment_B(sB_input)
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, self.num_acc_stage)
)
# Cluster wait before TMEM alloc and ensure pipelines are ready
if cutlass.const_expr(cute.size(self.cluster_shape_mn) > 1):
cute.arch.cluster_wait()
else:
self.cta_sync_barrier.arrive_and_wait()
# TMEM allocation
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
# Get the pointer to the TMEM buffer
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
accumulators = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCrA = None
if cutlass.const_expr(self.transform_a_source == tcgen05.OperandSource.TMEM):
tmem_ptr_transform = cute.recast_ptr(
accumulators.iterator + self.num_acc_tmem_cols, dtype=self.mma_dtype
)
tCrA = cute.make_tensor(
tmem_ptr_transform,
tiled_mma.make_fragment_A(a_smem_layout_transform.outer).layout,
)
else:
tCrA = tiled_mma.make_fragment_A(sA_transform)
# Specialized TMA load warp for A/B tensor
if warp_idx == self.tma_warp_id:
# Persistent tile scheduling loop
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
a_load2trans_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_load2trans_stage
)
b_load2mma_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_load2trans_stage
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
a_load2trans_producer_state.reset_count()
peek_load2trans_empty_status = cutlass.Boolean(1)
if a_load2trans_producer_state.count < k_tile_cnt:
peek_load2trans_empty_status = (
a_load2trans_pipeline.producer_try_acquire(
a_load2trans_producer_state
)
)
b_load2mma_producer_state.reset_count()
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
a_load2trans_pipeline.producer_acquire(
a_load2trans_producer_state, peek_load2trans_empty_status
)
b_load2mma_pipeline.producer_acquire(b_load2mma_producer_state)
# TMA load A/B
cute.copy(
tma_atom_a,
tAgA_slice[(None, a_load2trans_producer_state.count)],
tAsA[(None, a_load2trans_producer_state.index)],
tma_bar_ptr=a_load2trans_pipeline.producer_get_barrier(
a_load2trans_producer_state
),
mcast_mask=a_full_mcast_mask,
)
cute.copy(
tma_atom_b,
tBgB_slice[(None, b_load2mma_producer_state.count)],
tBsB[(None, b_load2mma_producer_state.index)],
tma_bar_ptr=b_load2mma_pipeline.producer_get_barrier(
b_load2mma_producer_state
),
mcast_mask=b_full_mcast_mask,
)
a_load2trans_pipeline.producer_commit(a_load2trans_producer_state)
b_load2mma_pipeline.producer_commit(b_load2mma_producer_state)
a_load2trans_producer_state.advance()
b_load2mma_producer_state.advance()
if a_load2trans_producer_state.count < k_tile_cnt:
peek_load2trans_empty_status = (
a_load2trans_pipeline.producer_try_acquire(
a_load2trans_producer_state
)
)
# Advance to next tile
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
# Wait A/B buffer empty
a_load2trans_pipeline.producer_tail(a_load2trans_producer_state)
b_load2mma_pipeline.producer_tail(b_load2mma_producer_state)
# Specialized TMA load for scale tensor
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
if warp_idx == self.scale_tma_warp_id:
# Persistent tile scheduling loop
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
scale_load2trans_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_scale_load2trans_stage
)
scale_k_tile_cnt = cute.size(mS_mkl.layout.shape[1][1])
while work_tile.is_valid_tile:
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
# ((atom_v, rest_v), RestK)
tSgS_slice = tSgS[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
# Filter zeros in rest mode
rest_filtered = cute.filter_zeros(tSgS_slice[(0, None)].layout)
tSgS_slice_filtered = cute.make_tensor(
tSgS_slice.iterator,
cute.make_layout(
(tSgS_slice.layout[0].shape, rest_filtered.shape),
stride=(tSgS_slice.layout[0].stride, rest_filtered.stride),
),
)
scale_load2trans_producer_state.reset_count()
peek_scale_load2trans_empty_status = cutlass.Boolean(1)
if scale_load2trans_producer_state.count < scale_k_tile_cnt:
peek_scale_load2trans_empty_status = (
scale_load2trans_pipeline.producer_try_acquire(
scale_load2trans_producer_state
)
)
for k_tile in cutlass.range(0, scale_k_tile_cnt, 1, unroll=1):
scale_load2trans_pipeline.producer_acquire(
scale_load2trans_producer_state,
peek_scale_load2trans_empty_status,
)
# TMA load scale
cute.copy(
tma_atom_s,
tSgS_slice_filtered[
(None, scale_load2trans_producer_state.count)
],
tSsS[(None, scale_load2trans_producer_state.index)],
tma_bar_ptr=scale_load2trans_pipeline.producer_get_barrier(
scale_load2trans_producer_state
),
mcast_mask=s_full_mcast_mask,
)
scale_load2trans_producer_state.advance()
peek_scale_load2trans_empty_status = cutlass.Boolean(1)
if scale_load2trans_producer_state.count < scale_k_tile_cnt:
peek_scale_load2trans_empty_status = (
scale_load2trans_pipeline.producer_try_acquire(
scale_load2trans_producer_state
)
)
# Advance to next tile
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
# Wait scale buffer empty
scale_load2trans_pipeline.producer_tail(scale_load2trans_producer_state)
# Specialized transform warps
if warp_idx >= self.transform_warp_id[0]:
transform_local_tidx = tidx - 32 * self.transform_warp_id[0]
# Partition tensors for transform input and output and set up the copy atom
# used for loading and storing transformed A tensor
(
src_copy_a,
dst_copy_a,
tAsA_input,
tAsA_transform,
) = self.transform_partition(
self.transform_a_source,
self.scale_mode,
copy_atom_a_input,
copy_atom_a_transform,
sA_input,
(
tCrA
if self.transform_a_source == tcgen05.OperandSource.TMEM
else sA_transform
),
transform_local_tidx,
)
# make fragment for input A and transformed A
tArA = cute.make_rmem_tensor(
cute.select(tAsA_input.layout, mode=[0, 1, 2, 3]).shape,
dtype=tAsA_input.element_type,
)
tArA_transform = cute.make_rmem_tensor(
cute.select(tAsA_input.layout, mode=[0, 1, 2, 3]).shape,
dtype=self.mma_dtype,
)
# Partition scale tensor
smem_thr_copy_S = None
tSsS_trans = None
tSrS_copy = None
tSrS = None
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
smem_thr_copy_S, tSsS_trans, tSrS_copy, tSrS = self.scale_partition(
src_copy_a, tCsS, transform_local_tidx, self.mma_dtype
)
assert cute.size(tSrS, mode=[0]) == cute.size(tArA, mode=[0]), (
"tSrS and tArA have different leading dimension"
)
assert cute.size(tSrS) == cute.size(tArA), (
"tSrS and tArA have different shape"
)
# Make all modes except mode0 into loops
tArA_load = cute.group_modes(tArA, 1, cute.rank(tArA))
tSrS_load = (
cute.group_modes(tSrS, 1, cute.rank(tSrS))
if self.scale_mode is TransformMode.ConvertScale
else None
)
tArA_transform_store = cute.group_modes(
tArA_transform, 1, cute.rank(tArA_transform)
)
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
a_load2trans_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer,
self.num_load2trans_stage,
)
scale_load2trans_consumer_state = (
pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer,
self.num_scale_load2trans_stage,
)
if self.scale_mode is TransformMode.ConvertScale
else None
)
trans2mma_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer,
self.num_trans2mma_stage,
)
while work_tile.is_valid_tile:
a_load2trans_consumer_state.reset_count()
peek_load2trans_full_status = cutlass.Boolean(1)
if a_load2trans_consumer_state.count < k_tile_cnt:
peek_load2trans_full_status = (
a_load2trans_pipeline.consumer_try_wait(
a_load2trans_consumer_state
)
)
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
scale_load2trans_consumer_state.reset_count()
trans2mma_producer_state.reset_count()
peek_trans2mma_empty_status = cutlass.Boolean(1)
if trans2mma_producer_state.count < k_tile_cnt:
peek_trans2mma_empty_status = (
trans2mma_pipeline.producer_try_acquire(
trans2mma_producer_state
)
)
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
a_load2trans_pipeline.consumer_wait(
a_load2trans_consumer_state, peek_load2trans_full_status
)
# Load A from shared memory
cute.autovec_copy(
tAsA_input[
(None, None, None, None, a_load2trans_consumer_state.index)
],
tArA,
)
if cutlass.const_expr(
self.scale_mode == TransformMode.ConvertScale
):
scale_load2trans_pipeline.consumer_wait(
scale_load2trans_consumer_state
)
trans2mma_pipeline.producer_acquire(
trans2mma_producer_state, peek_trans2mma_empty_status
)
# load scale tensor when needed
if cutlass.const_expr(
self.scale_mode == TransformMode.ConvertScale
):
if k_tile % num_k_tiles_per_scale == 0:
tSsS_slice = tSsS_trans[
(
None,
None,
None,
None,
scale_load2trans_consumer_state.index,
)
]
tSsS_slice_filtered = cute.make_tensor(
tSsS_slice.iterator,
cute.filter_zeros(tSsS_slice.layout),
)
cute.autovec_copy(tSsS_slice_filtered, tSrS_copy)
for idx in cutlass.range_constexpr(cute.size(tArA_load, mode=[1])):
# Load tensor A and convert it to mma dtype
tensor_transformed = (
tArA_load[(None, idx)].load().to(self.mma_dtype)
)
if cutlass.const_expr(
self.scale_mode == TransformMode.ConvertScale
):
scale = cute.TensorSSA(
tSrS_load[(None, idx)].load(),
tensor_transformed.shape,
self.mma_dtype,
)
# Apply scale
tensor_transformed = tensor_transformed * scale
tArA_transform_store[(None, idx)].store(tensor_transformed)
# Store transformed A to tensor memory or shared memory
if cutlass.const_expr(dst_copy_a is not None):
cute.copy(
dst_copy_a,
tArA_transform,
tAsA_transform[
(None, None, None, None, trans2mma_producer_state.index)
],
)
else:
cute.autovec_copy(
tArA_transform,
tAsA_transform[
(None, None, None, None, trans2mma_producer_state.index)
],
)
# Ensure all transform threads have finished the copy and reached the fence
self.transform_sync_barrier.arrive_and_wait()
if cutlass.const_expr(
self.transform_a_source == tcgen05.OperandSource.TMEM
):
cute.arch.fence_view_async_tmem_store()
else:
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
# Signal the completion of transformation
trans2mma_pipeline.producer_commit(trans2mma_producer_state)
# Signal the completion of using A and scale tensors
a_load2trans_pipeline.consumer_release(a_load2trans_consumer_state)
if cutlass.const_expr(
self.scale_mode == TransformMode.ConvertScale
):
scale_load2trans_pipeline.consumer_release(
scale_load2trans_consumer_state
)
if (k_tile + 1) % num_k_tiles_per_scale == 0:
scale_load2trans_consumer_state.advance()
a_load2trans_consumer_state.advance()
trans2mma_producer_state.advance()
if a_load2trans_consumer_state.count < k_tile_cnt:
peek_load2trans_full_status = (
a_load2trans_pipeline.consumer_try_wait(
a_load2trans_consumer_state
)
)
if trans2mma_producer_state.count < k_tile_cnt:
peek_trans2mma_empty_status = (
trans2mma_pipeline.producer_try_acquire(
trans2mma_producer_state
)
)
# Advance to next tile
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
# Wait a_transform buffer empty
trans2mma_pipeline.producer_tail(trans2mma_producer_state)
# Specialized MMA warp
if warp_idx == self.mma_warp_id:
tCtAcc_base = accumulators
# Persistent tile scheduling loop
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
trans2mma_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_trans2mma_stage
)
b_load2mma_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_load2trans_stage
)
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
while work_tile.is_valid_tile:
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
b_load2mma_consumer_state.reset_count()
trans2mma_consumer_state.reset_count()
peek_trans2mma_full_status = cutlass.Boolean(1)
if is_leader_cta:
if trans2mma_consumer_state.count < k_tile_cnt:
peek_trans2mma_full_status = (
trans2mma_pipeline.consumer_try_wait(
trans2mma_consumer_state
)
)
acc_pipeline.producer_acquire(acc_producer_state)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
# Mma mainloop
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
trans2mma_pipeline.consumer_wait(
trans2mma_consumer_state, peek_trans2mma_full_status
)
b_load2mma_pipeline.consumer_wait(b_load2mma_consumer_state)
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (
None,
None,
kblock_idx,
trans2mma_consumer_state.index,
)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
trans2mma_pipeline.consumer_release(trans2mma_consumer_state)
b_load2mma_pipeline.consumer_release(b_load2mma_consumer_state)
trans2mma_consumer_state.advance()
b_load2mma_consumer_state.advance()
peek_trans2mma_full_status = cutlass.Boolean(1)
if trans2mma_consumer_state.count < k_tile_cnt:
peek_trans2mma_full_status = (
trans2mma_pipeline.consumer_try_wait(
trans2mma_consumer_state
)
)
# Async arrive accumulator buffer full
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
# Advance to next tile
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
# Wait for accumulator buffer empty
acc_pipeline.producer_tail(acc_producer_state)
# Specialized epilogue warps
if warp_idx < self.mma_warp_id:
epi_tidx = tidx
tCtAcc_base = accumulators
# Partition for epilogue
(
tiled_copy_t2r,
tTR_tAcc_base,
tTR_rAcc,
) = self.epilog_tmem_copy_and_partition(
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
)
tTR_rC = None
tiled_copy_r2s = None
simt_atom = None
tRS_rC = None
tRS_sC = None
bSG_sC = None
bSG_gC_partitioned = None
tTR_gC_partitioned = None
if cutlass.const_expr(self.use_tma_store):
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
tiled_copy_t2r, tTR_rC, epi_tidx, sC
)
(
tma_atom_c,
bSG_sC,
bSG_gC_partitioned,
) = self.epilog_gmem_copy_and_partition(
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
)
else:
(
simt_atom,
tTR_rC,
tTR_gC_partitioned,
) = self.epilog_gmem_copy_and_partition(
epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC
)
# Persistent tile scheduling loop
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, (bidx, bidy, bidz), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
c_pipeline = None
if cutlass.const_expr(self.use_tma_store):
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilog_warp_id),
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
producer_group=c_producer_group,
)
while work_tile.is_valid_tile:
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
bSG_gC = None
tTR_gC = None
if cutlass.const_expr(self.use_tma_store):
bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)]
else:
tTR_gC = tTR_gC_partitioned[
(None, None, None, None, None, *mma_tile_coord_mnl)
]
tTR_tAcc = tTR_tAcc_base[
(None, None, None, None, None, acc_consumer_state.index)
]
# Wait for accumulator buffer full
acc_pipeline.consumer_wait(acc_consumer_state)
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
if cutlass.const_expr(self.use_tma_store):
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
else:
tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC))
# Store accumulator to global memory in subtiles
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
for subtile_idx in cutlass.range(subtile_cnt):
# Load accumulator from tensor memory buffer to register
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
if cutlass.const_expr(self.use_tma_store):
# Convert to C type
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
acc_vec = acc_vec.to(self.c_dtype)
tRS_rC.store(acc_vec)
c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage
# Store C to shared memory
cute.copy(
tiled_copy_r2s,
tRS_rC,
tRS_sC[(None, None, None, c_buffer)],
)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
self.epilog_sync_barrier.arrive_and_wait()
# TMA store C to global memory
if warp_idx == self.epilog_warp_id[0]:
cute.copy(
tma_atom_c,
bSG_sC[(None, c_buffer)],
bSG_gC[(None, subtile_idx)],
)
c_pipeline.producer_commit()
c_pipeline.producer_acquire()
self.epilog_sync_barrier.arrive_and_wait()
else:
# Convert to C type
acc_vec = tTR_rAcc.load()
acc_vec = acc_vec.to(self.c_dtype)
tTR_rC.store(acc_vec)
# Store C to global memory
cute.autovec_copy(
tTR_rC, tTR_gC[(None, None, None, subtile_idx)]
)
# Async arrive accumulator buffer empty
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
# Advance to next tile
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
# Dealloc the tensor memory buffer
tmem.relinquish_alloc_permit()
self.epilog_sync_barrier.arrive_and_wait()
tmem.free(tmem_ptr)
if cutlass.const_expr(self.use_tma_store):
c_pipeline.producer_tail()
def scale_tma_partition(
self,
tCsS: cute.Tensor,
tCgS: cute.Tensor,
tma_atom_s: cute.CopyAtom,
block_in_cluster_coord_vmnk: cute.Coord,
scale_cta_layout: cute.Layout,
) -> tuple[cute.Tensor, cute.Tensor]:
"""
Perform TMA partition for scale tensor.
This method partitions the gobal memory and shared memory buffer for scale tensor for TMA load.
:param tCsS: Input scale shared memory tensor
:type tCsS: cute.Tensor
:param tCgS: Input scale global memory tensor
:type tCgS: cute.Tensor
:param tma_atom_s: TMA copy atom for scale tensor
:type tma_atom_s: cute.CopyAtom
:param block_in_cluster_coord_vmnk: CTA coord in the cluster
:type block_in_cluster_coord_vmnk: cute.Coord
:param scale_cta_layout: Layout of CTA from the view of the scale tensor
:type scale_cta_layout: cute.Layout
:return: A tuple containing (tSsS, tSgS) where:
- tSsS: Partitioned scale tensor in shared memory
- tSgS: Partitioned scale tensor in global memory
:rtype: tuple[cute.Tensor, cute.Tensor]
"""
tSsS, tSgS = cpasync.tma_partition(
tma_atom_s,
block_in_cluster_coord_vmnk[2],
scale_cta_layout,
cute.group_modes(tCsS, 0, 3),
cute.group_modes(tCgS, 0, 3),
)
# Add rest_v mode
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), loopM, loopK, loopL)
tSsS = cute.make_tensor(
tSsS.iterator,
cute.make_layout(
((tSsS.layout.shape[0], 1), *tSsS.layout.shape[1:]),
stride=(
(tSsS.layout.stride[0], 0),
*tSsS.layout.stride[1:],
),
),
)
tSgS = cute.make_tensor(
tSgS.iterator,
cute.make_layout(
((tSgS.layout.shape[0], 1), *tSgS.layout.shape[1:]),
stride=(
(tSgS.layout.stride[0], 0),
*tSgS.layout.stride[1:],
),
),
)
return tSsS, tSgS
def transform_partition(
self,
transform_a_source: tcgen05.OperandSource,
scale_mode: TransformMode,
copy_atom_a_input: cute.CopyAtom,
copy_atom_a_transform: cute.CopyAtom,
sA_input: cute.Tensor,
A_transform: cute.Tensor,
transform_local_tidx: cutlass.Int32,
) -> tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Partition tensors for transform input and output.
This method sets up the copy atoms and partitions the shared/tensor memory
for the transformation of tensor A.
:param transform_a_source: Where the transformed tensor A is stored (TMEM or SMEM)
:type transform_a_source: tcgen05.OperandSource
:param scale_mode: The transform mode (ConvertOnly or ConvertScale)
:type scale_mode: TransformMode
:param copy_atom_a_input: Copy atom for loading A from shared memory
:type copy_atom_a_input: cute.CopyAtom
:param copy_atom_a_transform: Copy atom for storing transformed A
:type copy_atom_a_transform: cute.CopyAtom
:param sA_input: Input tensor A in shared memory
:type sA_input: cute.Tensor
:param A_transform: Transformed tensor A in tensor or shared memory
:type A_transform: cute.Tensor
:param transform_local_tidx: Local thread index for transformation warps
:type transform_local_tidx: cutlass.Int32
:return: A tuple containing (src_copy_a, dst_copy_a, tAsA_input, tA_transform) where:
- src_copy_a: Tiled copy for source tensor
- dst_copy_a: Tiled copy for destination tensor
- tAsA_input: Partitioned input tensor A
- tA_transform: Partitioned transformed tensor A
:rtype: tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor]
"""
if cutlass.const_expr(transform_a_source == tcgen05.OperandSource.TMEM):
if cutlass.const_expr(
cute.size(A_transform, mode=[0, 0]) == 128
and cute.size(sA_input, mode=[0, 0]) == 64
):
tensor_input = cute.make_tensor(
sA_input.iterator,
cute.logical_product(
sA_input.layout,
((cute.make_layout(2, stride=0), None), None, None, None),
),
)
else:
tensor_input = sA_input
reg2tmem_tiled_copy = tcgen05.make_tmem_copy(
copy_atom_a_transform, A_transform[(None, None, None, 0)]
)
thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(
transform_local_tidx
)
partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input)
partitioned_tensor_transform = thr_reg2tmem_tiled_copy.partition_D(
A_transform
)
src_copy_a = (
cute.make_tiled_copy_S(copy_atom_a_input, reg2tmem_tiled_copy)
if scale_mode is TransformMode.ConvertScale
else None
)
dst_copy_a = reg2tmem_tiled_copy
tAsA_input = partitioned_tensor_input
tA_transform = partitioned_tensor_transform
elif cutlass.const_expr(transform_a_source == tcgen05.OperandSource.SMEM):
# Construct tiled_copy satisfying 8 contiguous elts per copy atom
reg2smem_tiled_copy = cute.make_cotiled_copy(
copy_atom_a_transform,
cute.make_layout((128, 8), stride=(8, 1)),
A_transform[(None, None, None, 0)].layout,
)
thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(
transform_local_tidx
)
partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(sA_input)
partitioned_tensor_transform = thr_reg2smem_tiled_copy.partition_D(
A_transform
)
src_copy_a = (
cute.make_tiled_copy_S(copy_atom_a_input, reg2smem_tiled_copy)
if scale_mode is TransformMode.ConvertScale
else None
)
# auto-vec copy is enough for copy from register to shared memory here
dst_copy_a = None
tAsA_input = partitioned_tensor_input
tA_transform = partitioned_tensor_transform
return src_copy_a, dst_copy_a, tAsA_input, tA_transform
def scale_partition(
self,
src_copy_a: cute.TiledCopy,
tCsS: cute.Tensor,
transform_local_tidx: cutlass.Int32,
mma_dtype: type[cutlass.Numeric],
) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]:
"""
Partition the scale tensor for transformation.
This method prepares the copy atom and partitions the shared memory for the scale tensor.
:param src_copy_a: Tiled copy for the source tensor
:type src_copy_a: cute.TiledCopy
:param tCsS: Scale tensor in shared memory
:type tCsS: cute.Tensor
:param transform_local_tidx: Local thread index for transformation warps
:type transform_local_tidx: cutlass.Int32
:param mma_dtype: Data type for the MMA operation
:type mma_dtype: type[cutlass.Numeric]
:return: A tuple containing (smem_thr_copy_S, tSsS_trans, tSrS) where:
- smem_thr_copy_S: Tiled copy for the scale tensor
- tSsS_trans: Partitioned scale tensor for transformation
- tSrS_copy: Register fragment for the scale tensor
- tSrS: view of scale tensor used for transformation computation
:rtype: tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]
"""
smem_thr_copy_S = None
tSsS_trans = None
tSrS = None
# Partition scale tensor
smem_thr_copy_S = src_copy_a.get_slice(transform_local_tidx)
tSsS_trans = smem_thr_copy_S.partition_S(tCsS)
# Construct register fragment for scale tensor
tSsS_layout_per_stage = tSsS_trans[(None, None, None, None, 0)].layout
# tSrS for copy
tSrS_copy = cute.make_rmem_tensor(
cute.filter_zeros(tSsS_layout_per_stage).shape, mma_dtype
)
# tSrS view for transformation computation
tSrS = cute.make_tensor(
tSrS_copy.iterator,
cute.make_layout(
tSsS_layout_per_stage.shape, stride=tSrS_copy.layout.stride
),
)
return smem_thr_copy_S, tSsS_trans, tSrS_copy, tSrS
def get_gmem_layout_scale(
self, scale_shape_mkl: tuple[int, int, int]
) -> cute.Layout:
"""
Get the layout of the scale tensor in global memory.
:param scale_shape_mkl: The shape of the scale tensor (M, K, L).
:type scale_shape_mkl: tuple[int, int, int]
:return: The layout of the scale tensor in global memory.
:rtype: cute.Layout
"""
m, k, l = scale_shape_mkl
shape_scale = (
(self.scale_granularity_m, cute.ceil_div(m, self.scale_granularity_m)),
(self.scale_granularity_k, cute.ceil_div(k, self.scale_granularity_k)),
)
if cutlass.const_expr(self.scale_major_mode == tcgen05.OperandMajorMode.MN):
layout_mk = cute.make_layout(
shape_scale,
stride=(
(0, 1),
(0, cute.size(shape_scale[0][1])),
),
)
else:
layout_mk = cute.make_layout(
shape_scale,
stride=(
(0, cute.size(shape_scale[1][1])),
(0, 1),
),
)
return cute.make_layout(
(*layout_mk.shape, l),
stride=(*layout_mk.stride, cute.cosize(layout_mk)),
)
def get_smem_layout_scale(self) -> tuple[cute.ComposedLayout, cute.ComposedLayout]:
"""
Get the layout of the scale tensor in shared memory.
:return: A tuple containing:
- smem_layout_scale_per_stage: Shared memory layout for scale tensor per stage
- smem_layout_scale: Shared memory layout for scale tensor
:rtype: tuple[cute.ComposedLayout, cute.ComposedLayout]
"""
self.scale_tile_shape = (
(
cute.size(self.mma_tiler[0]) // 2
if self.use_2cta_instrs
else cute.size(self.mma_tiler[0])
),
cute.size(self.mma_tiler[2]),
)
size_mn = self.scale_tile_shape[0]
size_k = self.scale_tile_shape[1]
smem_size_mn = (
self.scale_granularity_m if self.scale_granularity_m < size_mn else size_mn
)
smem_size_k = (
self.scale_granularity_k if self.scale_granularity_k < size_k else size_k
)
div_mn = cute.ceil_div(size_mn, smem_size_mn)
div_k = cute.ceil_div(size_k, smem_size_k)
smem_atom_shape = (
(smem_size_mn, div_mn),
(smem_size_k, div_k),
)
if cutlass.const_expr(self.scale_major_mode == tcgen05.OperandMajorMode.MN):
outer_layout = cute.make_layout(
smem_atom_shape,
stride=(
(0, 1),
(0, div_mn),
),
)
else:
outer_layout = cute.make_layout(
smem_atom_shape,
stride=(
(0, div_k),
(0, 1),
),
)
# Apply a trivial swizzle to make it a composed layout, which could be used to construct TMA atom
smem_layout_scale_per_stage = cute.make_composed_layout(
cute.make_swizzle(0, 4, 3), 0, outer_layout
)
assert cute.rank(smem_layout_scale_per_stage) == 2, (
"Scale layout must be rank 2"
)
assert (
cute.size(self.mma_tiler[0])
% cute.size(smem_layout_scale_per_stage.outer[0])
== 0
), "smem_layout_scale_per_stage must equal the tile shape."
assert (
cute.size(self.mma_tiler[2])
% cute.size(smem_layout_scale_per_stage.outer[1])
== 0
), "smem_layout_scale_per_stage must evenly divide tile k shape."
# Shared memory buffer for scale must be at least 128B to satisfy TMA requirement
assert (
cute.size_in_bytes(self.a_scale_dtype, smem_layout_scale_per_stage) >= 128
), "smem size for scale must be at least 128B"
# Scale layout in smem with multiple stages
smem_layout_scale = cute.append(
smem_layout_scale_per_stage,
cute.make_layout(
(self.num_scale_load2trans_stage),
stride=(cute.cosize(smem_layout_scale_per_stage.outer)),
),
)
return smem_layout_scale_per_stage, smem_layout_scale
def epilog_gmem_copy_and_partition(
self,
tidx: cutlass.Int32,
atom: Union[cute.CopyAtom, cute.TiledCopy],
gC_mnl: cute.Tensor,
epi_tile: cute.Tile,
sC: cute.Tensor,
) -> tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]:
"""
Partitions source and destination tensors for a global memory store.
This method generates a tiled copy for storing results to global memory
and partitions the source (register or shared memory) and destination
(global memory) tensors accordingly. The behavior varies based on whether
TMA store is enabled.
:param tidx: The thread index in epilogue warp groups.
:type tidx: cutlass.Int32
:param atom: The copy atom to be used (TMA or universal).
:type atom: cute.CopyAtom or cute.TiledCopy
:param gC_mnl: The global tensor C.
:type gC_mnl: cute.Tensor
:param epi_tile: The epilogue tiler.
:type epi_tile: cute.Tile
:param sC: The shared memory tensor C.
:return: A tuple containing the appropriate copy atom and partitioned
source and destination tensors for the store operation.
:rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
"""
gC_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
if self.use_tma_store:
tma_atom_c = atom
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
# ((ATOM_V, REST_V), EPI_M, EPI_N)
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c,
0,
cute.make_layout(1),
sC_for_tma_partition,
gC_for_tma_partition,
)
return tma_atom_c, bSG_sC, bSG_gC
else:
tiled_copy_t2r = atom
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
tTR_gC = thr_copy_t2r.partition_D(gC_epi)
# (T2R, T2R_M, T2R_N)
tTR_rC = cute.make_rmem_tensor(
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype
)
simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype)
return simt_atom, tTR_rC, tTR_gC
def epilog_smem_copy_and_partition(
self,
tiled_copy_t2r: cute.TiledCopy,
tTR_rC: cute.Tensor,
tidx: cutlass.Int32,
sC: cute.Tensor,
) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Partitions source and destination tensors for a shared memory store.
This method generates a tiled copy for storing results to shared memory
and partitions the source (register) and destination (shared memory)
tensors accordingly.
:param tiled_copy_t2r: The tiled copy operation for tmem to register copy.
:param tTR_rC: The partitioned accumulator tensor.
:param tidx: The thread index in epilogue warp groups.
:param sC: The shared memory tensor to be copied and partitioned.
:return: A tuple containing the tiled copy for the store operation and
the partitioned source and destination tensors.
"""
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
# (R2S, R2S_M, R2S_N)
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
return tiled_copy_r2s, tRS_rC, tRS_sC
def epilog_tmem_copy_and_partition(
self,
tidx: cutlass.Int32,
tAcc: cute.Tensor,
gC_mnl: cute.Tensor,
epi_tile: cute.Tile,
use_2cta_instrs: Union[cutlass.Boolean, bool],
) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Partitions source and destination tensors for a tensor memory load.
This method generates a tiled copy for loading accumulators from tensor
memory and partitions the source (tensor memory) and destination
(register) tensors accordingly.
:param tidx: The thread index in epilogue warp groups.
:param tAcc: The accumulator tensor to be copied and partitioned.
:param gC_mnl: The global tensor C.
:param epi_tile: The epilogue tiler.
:param use_2cta_instrs: Whether use_2cta_instrs is enabled.
:return: A tuple containing the tiled copy for the load operation and
the partitioned source and destination tensors.
"""
# Make tiledCopy for tensor memory load
copy_atom_t2r = sm100_utils.get_tmem_load_op(
self.cta_tile_shape_mnk,
self.c_layout,
self.c_dtype,
self.acc_dtype,
epi_tile,
use_2cta_instrs,
)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
tAcc_epi = cute.flat_divide(
tAcc[((None, None), 0, 0, None)],
epi_tile,
)
# (EPI_TILE_M, EPI_TILE_N)
tiled_copy_t2r = tcgen05.make_tmem_copy(
copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]
)
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
gC_mnl_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
# (T2R, T2R_M, T2R_N)
tTR_rAcc = cute.make_rmem_tensor(
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype
)
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
@staticmethod
def align_up(x: int, align: int) -> int:
"""Align x up to the nearest multiple of align."""
return (x + align - 1) // align * align
@staticmethod
def _compute_stages_and_tmem_cols(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: tuple[int, int, int],
cta_tile_shape_mnk: tuple[int, int, int],
epi_tile: cute.Tile,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
c_layout: utils.LayoutEnum,
transform_a_source: tcgen05.OperandSource,
scale_granularity_m: int,
scale_granularity_k: int,
smem_buffer_align_bytes: int,
use_tma_store: bool,
scale_mode: TransformMode,
) -> tuple[int, int, int, int, int, int, int]:
"""
Compute pipeline stages and TMEM column allocation configurations.
This method calculates the number of pipeline stages for different operations
(load2trans, trans2mma, accumulator, etc.) and determines TMEM column allocation
based on available memory resources and tile configuration.
:param tiled_mma: The tiled MMA object defining the core computation.
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
:type mma_tiler_mnk: tuple[int, int, int]
:param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
:type cta_tile_shape_mnk: tuple[int, int, int]
:param epi_tile: The epilogue tile shape.
:type epi_tile: cute.Tile
:param a_dtype: Data type of operand A.
:type a_dtype: type[cutlass.Numeric]
:param b_dtype: Data type of operand B.
:type b_dtype: type[cutlass.Numeric]
:param c_dtype: Data type of operand C.
:type c_dtype: type[cutlass.Numeric]
:param c_layout: Layout enum of operand C.
:type c_layout: utils.LayoutEnum
:param transform_a_source: The source of the transformed A tensor.
:type transform_a_source: tcgen05.OperandSource
:param scale_granularity_m: The granularity of the scale tensor along the M mode.
:type scale_granularity_m: int
:param scale_granularity_k: The granularity of the scale tensor along the K mode.
:type scale_granularity_k: int
:param smem_buffer_align_bytes: The alignment of the shared memory buffer.
:type smem_buffer_align_bytes: int
:param use_tma_store: Whether TMA store is enabled.
:type use_tma_store: bool
:param scale_mode: The transform mode.
:type scale_mode: TransformMode
:return: A tuple containing the number of stages for:
(load2trans, scale_load2trans, transform2mma, accumulator, c, tmem_acc_cols, tmem_a_cols)
:rtype: tuple[int, int, int, int, int, int, int]
- num_load2trans_stage: Stages for load-to-transform A and B tensors pipeline
- num_scale_load2trans_stage: Stages for scale load-to-transform A tensor pipeline
- num_trans2mma_stage: Stages for transform-to-MMA pipeline
- num_acc_stage: Stages for accumulator-to-epilogue pipeline
- num_c_stage: Stages for epilogue-to-output C pipeline
- num_acc_tmem_cols: TMEM columns for accumulator
- num_a_tmem_cols: TMEM columns for transformed A tensor
"""
# Compute tmem columns required for accumulator
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
tCtAcc_stage1 = tiled_mma.make_fragment_C(cute.append(acc_shape, 1))
num_tmem_acc_col_per_stage = utils.get_num_tmem_alloc_cols(tCtAcc_stage1, True)
# Heuristic to decide the number of stages for accumulator
sm100_tmem_columns = cute.arch.SM100_TMEM_CAPACITY_COLUMNS
accumulator_stage_count = sm100_tmem_columns // num_tmem_acc_col_per_stage
if transform_a_source == tcgen05.OperandSource.TMEM:
if num_tmem_acc_col_per_stage < 128:
accumulator_stage_count = 3
elif num_tmem_acc_col_per_stage < 256:
accumulator_stage_count = 2
else:
accumulator_stage_count = 1
# transformed A in 16bit, thus 1 tmem column could hold 2 elements
num_elts_per_tmem_col = 32 // tiled_mma.op.a_dtype.width
num_tmem_cols_a_per_stage = MixedInputGemmKernel.align_up(
(
cta_tile_shape_mnk[2] // num_elts_per_tmem_col
if transform_a_source == tcgen05.OperandSource.TMEM
else 0
),
4,
)
c_stage_count = 2 if use_tma_store else 0
c_smem_layout_staged_one = (
sm100_utils.make_smem_layout_epi(
c_dtype,
c_layout,
epi_tile,
1,
)
if use_tma_store
else None
)
c_bytes_per_stage = (
cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
if use_tma_store
else 0
)
c_bytes = c_bytes_per_stage * c_stage_count
smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
bytes_per_pipeline_stage = 16
if scale_mode == TransformMode.ConvertOnly:
scale_load2trans_stage_count = 0
a_scale_bytes_per_stage = 0
else:
# Ensure we have 2 buffers for scale tiles needed for 1 CTA tile
a_scale_k_mode = max(cta_tile_shape_mnk[2] // scale_granularity_k, 1)
a_scale_m_mode = max(cta_tile_shape_mnk[0] // scale_granularity_m, 1)
scale_load2trans_stage_count = 2
a_scale_bytes_per_stage = MixedInputGemmKernel.align_up(
cute.size_in_bytes(
tiled_mma.op.a_dtype,
cute.make_layout((a_scale_m_mode, a_scale_k_mode)),
),
smem_buffer_align_bytes,
)
a_scale_bytes = (
a_scale_bytes_per_stage + bytes_per_pipeline_stage
) * scale_load2trans_stage_count
caveout_smem_bytes = (
bytes_per_pipeline_stage * accumulator_stage_count + a_scale_bytes + c_bytes
)
# Compute transform stages if A is in TMEM
num_tmem_acc_cols = MixedInputGemmKernel.align_up(
accumulator_stage_count * num_tmem_acc_col_per_stage, 4
)
transform2mma_stage_count_a_source_tmem_potential = (
(sm100_tmem_columns - num_tmem_acc_cols) // num_tmem_cols_a_per_stage
if transform_a_source == tcgen05.OperandSource.TMEM
else -1
)
if (
transform_a_source == tcgen05.OperandSource.TMEM
and transform2mma_stage_count_a_source_tmem_potential <= 0
):
raise ValueError("Not enough TMEM capacity for selected tile size")
a_load_bytes_per_stage = MixedInputGemmKernel.align_up(
cute.size_in_bytes(
a_dtype,
cute.make_layout((cta_tile_shape_mnk[0], cta_tile_shape_mnk[2])),
),
smem_buffer_align_bytes,
)
b_load_bytes_per_stage = MixedInputGemmKernel.align_up(
cute.size_in_bytes(
b_dtype,
cute.make_layout(
(
cta_tile_shape_mnk[1] // cute.size(tiled_mma.thr_id),
cta_tile_shape_mnk[2],
)
),
),
smem_buffer_align_bytes,
)
ab_load_bytes_per_stage = (
a_load_bytes_per_stage
+ b_load_bytes_per_stage
+ 2 * bytes_per_pipeline_stage
)
a_transform_bytes_per_stage = (
MixedInputGemmKernel.align_up(
cute.size_in_bytes(
tiled_mma.op.a_dtype,
cute.make_layout((cta_tile_shape_mnk[0], cta_tile_shape_mnk[2])),
),
smem_buffer_align_bytes,
)
if transform_a_source == tcgen05.OperandSource.SMEM
else 0
)
a_transform_bytes_per_stage = (
a_transform_bytes_per_stage + bytes_per_pipeline_stage
)
transform2mma_stage_count_a_source_smem_potential = (
smem_capacity - caveout_smem_bytes
) // (ab_load_bytes_per_stage + a_transform_bytes_per_stage)
transform2mma_stage_count = (
min(
transform2mma_stage_count_a_source_tmem_potential,
transform2mma_stage_count_a_source_smem_potential,
)
if transform_a_source == tcgen05.OperandSource.TMEM
else transform2mma_stage_count_a_source_smem_potential
)
load2transform_stage_count = (
smem_capacity
- caveout_smem_bytes
- (transform2mma_stage_count * a_transform_bytes_per_stage)
) // ab_load_bytes_per_stage
if (
load2transform_stage_count < 2
or transform2mma_stage_count < 2
or accumulator_stage_count < 1
):
raise ValueError("Not enough SMEM or TMEM capacity for selected tile size")
num_tmem_a_cols = transform2mma_stage_count * num_tmem_cols_a_per_stage
# Check if we can increase c_stage_count with leftover smem
if use_tma_store:
c_stage_count += (
smem_capacity
- load2transform_stage_count * ab_load_bytes_per_stage
- transform2mma_stage_count * a_transform_bytes_per_stage
- scale_load2trans_stage_count * a_scale_bytes_per_stage
- c_bytes
) // c_bytes_per_stage
return (
load2transform_stage_count,
scale_load2trans_stage_count,
transform2mma_stage_count,
accumulator_stage_count,
c_stage_count,
num_tmem_acc_cols,
num_tmem_a_cols,
)
@staticmethod
def _compute_smem_layout(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: tuple[int, int, int],
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
load2trans_stage_count: int,
trans2mma_stage_count: int,
) -> tuple[
cute.ComposedLayout,
cute.ComposedLayout,
cute.ComposedLayout,
]:
"""
Compute shared memory layouts for tensor A, transformed A and tensor B.
:param tiled_mma: The tiled MMA object defining the core computation.
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
:type mma_tiler_mnk: tuple[int, int, int]
:param a_dtype: Data type of operand A.
:type a_dtype: type[cutlass.Numeric]
:param b_dtype: Data type of operand B.
:type b_dtype: type[cutlass.Numeric]
:param load2trans_stage_count: Number of stages for load-to-transform pipeline.
:type load2trans_stage_count: int
:param trans2mma_stage_count: Number of stages for transform-to-MMA pipeline.
:type trans2mma_stage_count: int
:return: A tuple containing:
- smem_layout_a: Shared memory layout for tensor A
- smem_layout_a_transform: Shared memory layout for transformed tensor A
- smem_layout_b: Shared memory layout for tensor B
:rtype: tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
"""
smem_layout_a = sm100_utils.make_smem_layout_a(
tiled_mma,
mma_tiler_mnk,
a_dtype,
load2trans_stage_count,
)
smem_layout_a_transform = sm100_utils.make_smem_layout_a(
tiled_mma,
mma_tiler_mnk,
tiled_mma.op.a_dtype,
trans2mma_stage_count,
)
smem_layout_b = sm100_utils.make_smem_layout_b(
tiled_mma,
mma_tiler_mnk,
b_dtype,
load2trans_stage_count,
)
return (
smem_layout_a,
smem_layout_a_transform,
smem_layout_b,
)
@staticmethod
def _get_transform_a_source(
a_major_mode: tcgen05.OperandMajorMode,
) -> tcgen05.OperandSource:
"""
Determine the operand source for transformed A tensor based on the operand major mode.
"""
if cutlass.const_expr(a_major_mode == tcgen05.OperandMajorMode.K):
return tcgen05.OperandSource.TMEM
else:
return tcgen05.OperandSource.SMEM
@staticmethod
def _get_tma_atom_kind(
mcast: cutlass.Boolean,
use_2cta_instrs: bool,
is_b: bool,
) -> Union[
cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp
]:
"""
Get the TMA atom kind based on 1) whether it's a multicast operation,
2) whether 2CTA tcgen05.mma instruction is enabled, and
3) whether it's a B tensor
"""
# Not using .2CTA instructions for tensor A as the consumer is threads on different CTAs
cta_group = (
tcgen05.CtaGroup.TWO if (use_2cta_instrs and is_b) else tcgen05.CtaGroup.ONE
)
if cutlass.const_expr(mcast):
return cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group)
return cpasync.CopyBulkTensorTileG2SOp(cta_group)
@staticmethod
def _get_copy_atom_a_transform(
mma_dtype: type[cutlass.Numeric],
use_2cta_instrs: bool,
transform_a_source: tcgen05.OperandSource,
a_smem_shape: cute.Shape,
a_dtype: type[cutlass.Numeric],
) -> cute.CopyAtom:
"""
Determine the copy atom for transformed A tensor based on the operand source and tile size.
"""
if cutlass.const_expr(transform_a_source == tcgen05.OperandSource.TMEM):
if cutlass.const_expr(
cute.size(a_smem_shape[0][0]) == 64 and (not use_2cta_instrs)
):
copy_op_r2t = tcgen05.St16x256bOp(
tcgen05.Repetition(1), tcgen05.Unpack.NONE
)
else:
copy_op_r2t = tcgen05.St32x32bOp(
tcgen05.Repetition(8), tcgen05.Unpack.NONE
)
return cute.make_copy_atom(copy_op_r2t, mma_dtype)
else:
return cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), a_dtype, num_bits_per_copy=32
)
@staticmethod
def _compute_grid(
c: cute.Tensor,
cta_tile_shape_mnk: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
max_active_clusters: cutlass.Constexpr,
) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]:
"""
Use persistent tile scheduler to compute the grid size for the output tensor C.
"""
c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
gc = cute.zipped_divide(c, tiler=c_shape)
num_ctas_mnl = gc[(0, (None, None, None))].shape
cluster_shape_mnl = (*cluster_shape_mn, 1)
tile_sched_params = utils.PersistentTileSchedulerParams(
num_ctas_mnl, cluster_shape_mnl
)
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
tile_sched_params, max_active_clusters
)
return tile_sched_params, grid
def is_valid_scale_granularity(
scale_granularity_m: int,
scale_granularity_k: int,
a_dtype: type[cutlass.Numeric],
k: int,
mma_tiler_k: int,
) -> bool:
"""
Check if the scale granularity settings are valid for the given data type and problem size.
"""
if a_dtype.width == 8:
# No scale tensor for 8bit data type A
if not (scale_granularity_m == 0 and scale_granularity_k == 0):
return False
elif a_dtype.width == 4:
if scale_granularity_m != 1 or (
scale_granularity_k == 0
or k % scale_granularity_k != 0
or scale_granularity_k % mma_tiler_k != 0
):
return False
return True
def is_valid_tensor_alignment(
m: int,
n: int,
k: int,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
scale_dtype: type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mnk: tuple[int, int, int],
use_2cta_instrs: bool,
cluster_shape_mn: tuple[int, int],
scale_granularity_m: int,
scale_granularity_k: int,
) -> bool:
"""
Check if the tensor alignments are valid for the given problem size and data types.
"""
def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):
major_mode_idx = 0 if is_mode0_major else 1
num_major_elements = tensor_shape[major_mode_idx]
num_contiguous_elements = 16 * 8 // dtype.width
return num_major_elements % num_contiguous_elements == 0
if not (
check_contiguous_16B_alignment(a_dtype, a_major == "m", (m, k))
and check_contiguous_16B_alignment(b_dtype, b_major == "n", (n, k))
and check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n))
and (
scale_granularity_k == 0
or check_contiguous_16B_alignment(
b_dtype, True, (m, k // scale_granularity_k)
)
)
):
return False
# Check if scale tensor matches the TMA load 128B alignment requirement
cta_tile_shape_mnk = (
mma_tiler_mnk[0] // (2 if use_2cta_instrs else 1),
mma_tiler_mnk[1],
mma_tiler_mnk[2],
)
if (
scale_granularity_m > 0
and (cta_tile_shape_mnk[0] // cluster_shape_mn[1] // scale_granularity_m)
* (scale_dtype.width // 8)
< 128
):
return False
return True
def is_valid_epilog_store_option(
m: int,
n: int,
mma_tiler_mn: tuple[int, int],
use_tma_store: bool,
use_2cta_instrs: bool,
) -> bool:
"""
Check if the epilogue store option is valid for the given problem size.
"""
cta_tile_shape_mn = (
mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
mma_tiler_mn[1],
)
# No OOB tile support when TMA store is disabled
if not use_tma_store:
if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0):
return False
return True
def is_valid_mma_tiler_and_cluster_shape(
mma_tiler: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
use_2cta_instrs: bool,
) -> bool:
"""
Check if the MMA tiler and cluster shape are valid for the given problem size.
"""
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
return False
if (mma_tiler[0] // (2 if use_2cta_instrs else 1)) not in [64, 128]:
return False
return True
def can_implement(
mnkl: tuple[int, int, int, int],
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
scale_granularity_m: int,
scale_granularity_k: int,
mma_tiler: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
use_2cta_instrs: bool,
use_tma_store: bool,
) -> bool:
"""
Check if the kernel can be implemented for the given tensor shapes and data types.
"""
m, n, k, l = mnkl
if not MixedInputGemmKernel.is_valid_mma_tiler_and_cluster_shape(
mma_tiler, cluster_shape_mn, use_2cta_instrs
):
return False
if not MixedInputGemmKernel.is_valid_scale_granularity(
scale_granularity_m, scale_granularity_k, a_dtype, k, mma_tiler[2]
):
return False
if not MixedInputGemmKernel.is_valid_tensor_alignment(
m,
n,
k,
a_dtype,
b_dtype,
c_dtype,
b_dtype,
a_major,
b_major,
c_major,
mma_tiler,
use_2cta_instrs,
cluster_shape_mn,
scale_granularity_m,
scale_granularity_k,
):
return False
if not MixedInputGemmKernel.is_valid_epilog_store_option(
m, n, mma_tiler[:2], use_tma_store, use_2cta_instrs
):
return False
return True
def create_i4_tensor_and_scale(
l: int,
m: int,
k: int,
is_m_major: bool,
dtype: type[cutlass.Numeric],
scale_granularity_m: int,
scale_granularity_k: int,
is_dynamic_layout: bool = True,
init_config: tuple = (
cutlass_torch.TensorInitType.RANDOM,
cutlass_torch.RandomInitConfig(min_val=-7, max_val=6),
),
divisibility: int = 16,
transformed_dtype: Optional[type[cutlass.Numeric]] = None,
) -> tuple[
cute.Tensor, torch.Tensor, torch.Tensor, cute.Tensor, torch.Tensor, torch.Tensor
]:
"""
Create quantized 4-bit tensor and corresponding scale tensor.
"""
lb_4b = -8 if dtype == cutlass.Int4 else 0
up_4b = 7 if dtype == cutlass.Int4 else 15
if not (
init_config[0] == cutlass_torch.TensorInitType.RANDOM
or init_config[0] == cutlass_torch.TensorInitType.SCALAR
):
raise ValueError(
"Only random and scalar initialization is supported for 4bit data type"
)
# Construct reference tensor in f32
ref_fp32 = cutlass_torch.matrix(l, m, k, is_m_major, cutlass.Float32, *init_config)
# Generate scale data and perform quantization
num_scales = k // scale_granularity_k
ref = ref_fp32.to(dtype=cutlass_torch.dtype(transformed_dtype)).reshape(
m, num_scales, scale_granularity_k, l
)
# Get elements with maximum absolute value to compute scaling factors
a_max = torch.maximum(ref / up_4b, ref / lb_4b)
a_scales, _ = torch.max(a_max, dim=2, keepdim=True)
a_scale_inv = torch.where(a_scales == 0, 0, 1 / a_scales)
a_quant = ref * a_scale_inv
# Convert values to integer to avoid computation errors
a_quant = a_quant.to(dtype=torch.int32).reshape((m, k, l)).to(dtype=torch.float32)
# Construct A quantized tensor
cute_a_quant_tensor, torch_a_quant_tensor = cutlass_torch.cute_tensor_like(
a_quant, dtype, is_dynamic_layout=is_dynamic_layout, assumed_align=divisibility
)
# Construct cute scale tensor
a_scales = a_scales.random_(-3, 3).reshape((m, num_scales, l))
# Scale tensor is always m-major
a_scales = a_scales.permute(2, 1, 0).contiguous().permute(2, 1, 0).to(device="cuda")
cute_scale_tensor = from_dlpack(a_scales, assumed_align=divisibility)
for i, stride in enumerate(a_scales.stride()):
if stride == 1:
leading_dim = i
break
if is_dynamic_layout:
cute_scale_tensor = cute_scale_tensor.mark_layout_dynamic(
leading_dim=leading_dim
)
return (
cute_a_quant_tensor,
torch_a_quant_tensor,
a_quant.to("cpu"),
cute_scale_tensor,
a_scales,
a_scales.to("cpu"),
)
def get_divisibility(contiguous_dim_size: int, upper_bound: int = 128) -> int:
"""
Calculate the largest power of 2 divisibility factor for memory alignment.
"""
# Check the largest power of 2 factor of contiguous_dim_size
for i in range(int(log2(contiguous_dim_size)), 0, -1):
if contiguous_dim_size % (2**i) == 0:
return min(2**i, upper_bound)
return 1
def create_tensor_a(
l: int,
m: int,
k: int,
a_major: str,
a_dtype: type[cutlass.Numeric],
scale_granularity_m: int = 0,
scale_granularity_k: int = 0,
transformed_dtype: Optional[type[cutlass.Numeric]] = None,
) -> tuple[cute.Tensor, cute.Tensor, torch.Tensor, torch.Tensor]:
"""
Create tensor A and scale tensor.
"""
a_scale_tensor = None
a_scale_torch_cpu = None
if a_dtype in (cutlass.Int4,):
(
a_tensor,
a_torch_gpu,
a_torch_cpu,
a_scale_tensor,
a_scale_torch_gpu,
a_scale_torch_cpu,
) = create_i4_tensor_and_scale(
l,
m,
k,
a_major == "m",
a_dtype,
scale_granularity_m,
scale_granularity_k,
divisibility=get_divisibility(m if a_major == "m" else k),
transformed_dtype=transformed_dtype,
)
else:
a_torch_cpu = cutlass_torch.matrix(
l,
m,
k,
a_major == "m",
a_dtype,
)
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_torch_cpu,
a_dtype,
is_dynamic_layout=True,
assumed_align=get_divisibility(m if a_major == "m" else k),
)
return a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu
def create_tensors(
l: int,
m: int,
n: int,
k: int,
a_major: str,
b_major: str,
c_major: str,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
scale_granularity_m: int = 0,
scale_granularity_k: int = 0,
) -> tuple:
"""
Create all input and output tensors for the mixed-input GEMM.
"""
torch.manual_seed(2025)
a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu = create_tensor_a(
l, m, k, a_major, a_dtype, scale_granularity_m, scale_granularity_k, b_dtype
)
b_torch_cpu = cutlass_torch.matrix(
l,
n,
k,
b_major == "n",
b_dtype,
cutlass_torch.TensorInitType.RANDOM,
cutlass_torch.RandomInitConfig(min_val=-10, max_val=10),
)
c_torch_cpu = cutlass_torch.matrix(
l,
m,
n,
c_major == "m",
c_dtype,
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu,
b_dtype,
is_dynamic_layout=True,
assumed_align=get_divisibility(n if b_major == "n" else k),
)
c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like(
c_torch_cpu,
c_dtype,
is_dynamic_layout=True,
assumed_align=get_divisibility(m if c_major == "m" else n),
)
c_tensor = c_tensor.mark_compact_shape_dynamic(
mode=(0 if c_major == "m" else 1),
stride_order=(2, 1, 0) if c_major == "m" else (2, 0, 1),
divisibility=get_divisibility(m if c_major == "m" else n),
)
return (
a_tensor,
a_scale_tensor,
b_tensor,
c_tensor,
a_torch_cpu,
a_scale_torch_cpu,
b_torch_cpu,
c_torch_gpu,
)
def compare(
a_torch_cpu: torch.Tensor,
b_torch_cpu: torch.Tensor,
a_scale_torch_cpu: Optional[torch.Tensor],
c_torch_gpu: torch.Tensor,
c_dtype: type[cutlass.Numeric],
tolerance: float,
) -> None:
"""
Compare kernel result with reference computation.
"""
kernel_result = c_torch_gpu.cpu()
# Compute reference result
if a_scale_torch_cpu is not None:
scale_shape = a_scale_torch_cpu.shape
a_shape = a_torch_cpu.shape
a_scale_torch_cpu = a_scale_torch_cpu.to(dtype=torch.float32).reshape(
scale_shape[0], scale_shape[1], 1, scale_shape[2]
)
a_torch_cpu = a_torch_cpu.to(dtype=torch.float32).reshape(
a_torch_cpu.shape[0], scale_shape[1], -1, a_torch_cpu.shape[2]
)
a_dequant = a_torch_cpu * a_scale_torch_cpu
ref = torch.einsum(
"mkl,nkl->mnl",
a_dequant.reshape(a_shape),
b_torch_cpu.to(dtype=torch.float32),
)
else:
ref = torch.einsum(
"mkl,nkl->mnl",
a_torch_cpu.to(dtype=torch.float32),
b_torch_cpu.to(dtype=torch.float32),
)
# Convert ref to c_dtype
_, ref_torch_gpu = cutlass_torch.cute_tensor_like(
ref, c_dtype, is_dynamic_layout=True, assumed_align=16
)
ref_result = ref_torch_gpu.cpu()
# Assert close results
torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05)
def run(
mnkl: tuple[int, int, int, int],
scale_granularity_m: int,
scale_granularity_k: int,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
acc_dtype: type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mnk: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
use_2cta_instrs: bool,
use_tma_store: bool,
tolerance: float,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
) -> None:
"""
Run the mixed-input GEMM kernel with specified parameters.
This function creates tensors, validates parameters, executes the kernel,
optionally compares results with a reference implementation and reports
kernel execution time.
"""
m, n, k, l = mnkl
if not torch.cuda.is_available():
raise ValueError("CUDA is not available")
# Check if given configuration is supported
if not MixedInputGemmKernel.can_implement(
mnkl,
a_dtype,
b_dtype,
c_dtype,
a_major,
b_major,
c_major,
scale_granularity_m,
scale_granularity_k,
mma_tiler_mnk,
cluster_shape_mn,
use_2cta_instrs,
use_tma_store,
):
raise ValueError("GEMM configuration not supported")
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
(
a_tensor,
a_scale_tensor,
b_tensor,
c_tensor,
a_torch_cpu,
a_scale_torch_cpu,
b_torch_cpu,
c_torch_gpu,
) = create_tensors(
l,
m,
n,
k,
a_major,
b_major,
c_major,
a_dtype,
b_dtype,
c_dtype,
scale_granularity_m,
scale_granularity_k,
)
mixed_input_gemm = MixedInputGemmKernel(
scale_granularity_m,
scale_granularity_k,
acc_dtype,
use_2cta_instrs,
mma_tiler_mnk,
cluster_shape_mn,
use_tma_store,
)
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1],
)
compiled_kernel = cute.compile(
mixed_input_gemm,
a_tensor,
a_scale_tensor,
b_tensor,
c_tensor,
max_active_clusters,
current_stream,
)
if not skip_ref_check:
compiled_kernel(
a_tensor,
a_scale_tensor,
b_tensor,
c_tensor,
current_stream,
)
compare(
a_torch_cpu, b_torch_cpu, a_scale_torch_cpu, c_torch_gpu, c_dtype, tolerance
)
# Early return if no performance measurement is needed
if iterations <= 0:
return
def generate_tensors():
a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu = create_tensor_a(
l, m, k, a_major, a_dtype, scale_granularity_m, scale_granularity_k, b_dtype
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu,
b_dtype,
is_dynamic_layout=True,
assumed_align=get_divisibility(n if b_major == "n" else k),
)
c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major == "m", c_dtype)
c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like(
c_torch_cpu,
c_dtype,
is_dynamic_layout=True,
assumed_align=get_divisibility(m if c_major == "m" else n),
)
c_tensor = c_tensor.mark_compact_shape_dynamic(
mode=(0 if c_major == "m" else 1),
stride_order=(2, 1, 0) if c_major == "m" else (2, 0, 1),
divisibility=get_divisibility(m if c_major == "m" else n),
)
return testing.JitArguments(
a_tensor, a_scale_tensor, b_tensor, c_tensor, current_stream
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch_cpu.numel() * a_torch_cpu.element_size()
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
+ c_torch_gpu.numel() * c_torch_gpu.element_size()
+ a_scale_torch_cpu.numel() * a_scale_torch_cpu.element_size()
if a_scale_torch_cpu is not None
else 0
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_kernel,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
def parse_comma_separated_ints(s: str) -> tuple[int, ...]:
try:
return tuple(int(x.strip()) for x in s.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--mnkl", type=parse_comma_separated_ints, default=(128, 128, 128, 1)
)
parser.add_argument(
"--mma_tiler_mnk", type=parse_comma_separated_ints, default=(128, 128, 128)
)
parser.add_argument(
"--cluster_shape_mn", type=parse_comma_separated_ints, default=(1, 1)
)
parser.add_argument(
"--use_2cta_instrs",
action="store_true",
help="Enable 2CTA MMA instructions feature",
)
parser.add_argument(
"--a_dtype",
type=cutlass.dtype,
default=cutlass.Int4,
choices=[cutlass.Int8, cutlass.Uint8, cutlass.Int4],
)
parser.add_argument(
"--b_dtype",
type=cutlass.dtype,
default=cutlass.BFloat16,
choices=[cutlass.BFloat16, cutlass.Float16],
)
parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
parser.add_argument("--a_major", choices=["k", "m"], type=str, default="m")
parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
parser.add_argument(
"--scale_granularity_m",
type=int,
default=1,
help="Scale granularity along M dimension.",
)
parser.add_argument(
"--scale_granularity_k",
type=int,
default=128,
help="Scale granularity along K dimension.",
)
parser.add_argument(
"--use_tma_store", action="store_true", help="Use tma store or not"
)
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
args = parser.parse_args()
run(
args.mnkl,
args.scale_granularity_m,
args.scale_granularity_k,
args.a_dtype,
args.b_dtype,
args.c_dtype,
args.acc_dtype,
args.a_major,
args.b_major,
args.c_major,
args.mma_tiler_mnk,
args.cluster_shape_mn,
args.use_2cta_instrs,
args.use_tma_store,
args.tolerance,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
)
print("PASS")