# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
from typing import Optional, Type, Tuple, Union

import cuda.bindings.driver as cuda
import torch

import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.cute.testing as testing
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack


"""
|
|
A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
|
|
using CUTE DSL.
|
|
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
|
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
|
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
|
|
|
This GEMM kernel supports the following features:
|
|
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
|
|
- Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions)
|
|
- Implements TMA multicast with cluster to reduce L2 memory traffic
|
|
- Support persistent tile scheduling to better overlap memory load/store with mma between tiles
|
|
- Support warp specialization to avoid explicit pipelining between mainloop load and mma
|
|
|
|
This GEMM works as follows:
|
|
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
|
2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
|
3. EPILOGUE warp:
|
|
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
|
- Type convert C matrix to output type.
|
|
- Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
|
|
or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
|
|
- Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
|
|
e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
|
|
|
|
SM100 tcgen05.mma instructions operate as follows:
|
|
- Read matrix A from SMEM
|
|
- Read matrix B from SMEM
|
|
- Write accumulator to TMEM
|
|
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
|
|
|
Input arguments to this example is same as dense_gemm.py.
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/blackwell/dense_gemm_persistent.py \
|
|
--ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
|
|
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
|
--mnkl 8192,8192,8192,1 \
|
|
--use_tma_store --use_2cta_instrs
|
|
|
|
To collect performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
|
|
ncu python examples/blackwell/dense_gemm_persistent.py \
|
|
--ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
|
|
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
|
--mnkl 8192,8192,8192,1 \
|
|
--use_tma_store --use_2cta_instrs \
|
|
--warmup_iterations 1 --iterations 10 --skip_ref_check
|
|
|
|
|
|
Constraints are same as dense_gemm.py:
|
|
* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
|
|
see detailed valid dtype combinations in below PersistentDenseGemmKernel class documentation
|
|
* A/B tensor must have the same data type
|
|
* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
|
|
* Mma tiler N must be 32-256, step 32
|
|
* Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
|
* Cluster shape M must be multiple of 2 if use_2cta_instrs=True
|
|
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
|
i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32,
|
|
Float16/BFloat16, and Int8/Uint8/Float8, respectively.
|
|
* OOB tiles are not allowed when TMA store is disabled
|
|
"""
|
|
|
|
|
|
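# Illustrative only (not part of the original kernel): a custom elementwise epilogue that
# can be passed as `epilogue_op` to PersistentDenseGemmKernel.__call__. It mirrors the
# ReLU lambda shown in the module docstring above; `cute.where` and `cute.full_like` are
# the same calls referenced there. Any other epilogue is assumed to be a pure elementwise
# function of the converted accumulator vector.
def _relu_epilogue_example(x):
    # Clamp negative accumulator values to zero before the C store.
    return cute.where(x > 0, x, cute.full_like(x, 0))

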
class PersistentDenseGemmKernel:
    """This class implements batched matrix multiplication (C = A x B) with support for various data types
    and architectural features specific to Blackwell GPUs, using persistent tile scheduling and warp specialization.

    :param acc_dtype: Data type for accumulation during computation
    :type acc_dtype: type[cutlass.Numeric]
    :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
    :type cluster_shape_mn: Tuple[int, int]
    :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results
    :type use_tma_store: bool

    :note: In the current version, the A and B tensors must have the same data type
        - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported

    :note: Supported A/B data types:
        - TFloat32
        - Float16/BFloat16
        - Int8/Uint8
        - Float8E4M3FN/Float8E5M2

    :note: Supported accumulator data types:
        - Float32 (for all floating point A/B data types)
        - Float16 (only for fp16 and fp8 A/B data types)
        - Int32 (only for uint8/int8 A/B data types)

    :note: Supported C data types:
        - Float32 (for float32 and int32 accumulator data types)
        - Int32 (for float32 and int32 accumulator data types)
        - Float16/BFloat16 (for fp16 and fp8 accumulator data types)
        - Int8/Uint8 (for uint8/int8 accumulator data types)
        - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)

    :note: Constraints:
        - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
        - MMA tiler N must be 32-256, in steps of 32
        - Cluster shape M must be a multiple of 2 if use_2cta_instrs=True
        - Cluster shape M/N must be positive powers of 2, with total cluster size <= 16

    Example:
        >>> gemm = PersistentDenseGemmKernel(
        ...     acc_dtype=cutlass.Float32,
        ...     use_2cta_instrs=True,
        ...     mma_tiler_mn=(128, 128),
        ...     cluster_shape_mn=(2, 2)
        ... )
        >>> gemm(a_tensor, b_tensor, c_tensor, max_active_clusters, stream)
    """

def __init__(
|
|
self,
|
|
acc_dtype: Type[cutlass.Numeric],
|
|
use_2cta_instrs: bool,
|
|
mma_tiler_mn: Tuple[int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
use_tma_store: bool,
|
|
):
|
|
"""Initializes the configuration for a Blackwell dense GEMM kernel.
|
|
|
|
This configuration includes several key aspects:
|
|
|
|
1. MMA Instruction Settings (tcgen05):
|
|
- acc_dtype: Data types for MMA accumulator.
|
|
- mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
|
|
- use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant
|
|
with cta_group=2 should be used.
|
|
|
|
2. Cluster Shape:
|
|
- cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
|
|
|
|
3. Output C tensor store mode:
|
|
- use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results.
|
|
|
|
:param acc_dtype: Data type of the accumulator.
|
|
:type acc_dtype: type[cutlass.Numeric]
|
|
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant.
|
|
:type use_2cta_instrs: bool
|
|
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
|
|
:type cluster_shape_mn: Tuple[int, int]
|
|
:param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor.
|
|
:type use_tma_store: bool
|
|
"""
|
|
|
|
self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
|
|
self.use_2cta_instrs = use_2cta_instrs
|
|
self.cluster_shape_mn = cluster_shape_mn
|
|
# K dimension is deferred in _setup_attributes
|
|
self.mma_tiler = (*mma_tiler_mn, 1)
|
|
self.use_tma_store = use_tma_store
|
|
|
|
self.cta_group = (
|
|
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
|
)
|
|
|
|
self.occupancy = 1
|
|
# Set specialized warp ids
|
|
self.epilog_warp_id = (
|
|
0,
|
|
1,
|
|
2,
|
|
3,
|
|
)
|
|
self.mma_warp_id = 4
|
|
self.tma_warp_id = 5
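        # 4 epilogue warps + 1 MMA warp + 1 TMA warp = 6 warps, i.e. 192 threads per CTA.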
|
|
self.threads_per_cta = 32 * len(
|
|
(self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)
|
|
)
|
|
# Set barrier id for cta sync, epilogue sync and tmem ptr sync
|
|
self.cta_sync_bar_id = 0
|
|
self.epilog_sync_bar_id = 1
|
|
self.tmem_ptr_sync_bar_id = 2
|
|
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
|
|
|
def _setup_attributes(self):
|
|
"""Set up configurations that are dependent on GEMM inputs
|
|
|
|
This method configures various attributes based on the input tensor properties
|
|
(data types, leading dimensions) and kernel settings:
|
|
- Configuring tiled MMA
|
|
- Computing MMA/cluster/tile shapes
|
|
- Computing cluster layout
|
|
- Computing multicast CTAs for A/B
|
|
- Computing epilogue subtile
|
|
- Setting up A/B/C stage counts in shared memory
|
|
- Computing A/B/C shared memory layout
|
|
- Computing tensor memory allocation columns
|
|
"""
|
|
# Configure tiled mma
|
|
tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
self.a_dtype,
|
|
self.a_major_mode,
|
|
self.b_major_mode,
|
|
self.acc_dtype,
|
|
self.cta_group,
|
|
self.mma_tiler[:2],
|
|
)
|
|
|
|
# Compute mma/cluster/tile shapes
|
|
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
|
|
mma_inst_tile_k = 4
|
|
self.mma_tiler = (
|
|
self.mma_tiler[0],
|
|
self.mma_tiler[1],
|
|
mma_inst_shape_k * mma_inst_tile_k,
|
|
)
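        # Note (illustrative): the K extent of the CTA tile is mma_inst_shape_k * 4;
        # e.g., for Float16 inputs the tcgen05.mma instruction K is 16, giving a 64-element K tile.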
|
|
self.cta_tile_shape_mnk = (
|
|
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
|
self.mma_tiler[1],
|
|
self.mma_tiler[2],
|
|
)
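        # With use_2cta_instrs, tiled_mma.thr_id has size 2, so each CTA in the pair owns
        # half of the MMA tile along M; otherwise the CTA tile equals the MMA tile.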
|
|
|
|
# Compute cluster layout
|
|
self.cluster_layout_vmnk = cute.tiled_divide(
|
|
cute.make_layout((*self.cluster_shape_mn, 1)),
|
|
(tiled_mma.thr_id.shape,),
|
|
)
|
|
|
|
# Compute number of multicast CTAs for A/B
|
|
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
|
|
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
|
|
self.is_a_mcast = self.num_mcast_ctas_a > 1
|
|
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
|
|
|
# Compute epilogue subtile
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
|
|
self.cta_tile_shape_mnk,
|
|
self.use_2cta_instrs,
|
|
self.c_layout,
|
|
self.c_dtype,
|
|
)
|
|
else:
|
|
self.epi_tile = self.cta_tile_shape_mnk[:2]
|
|
|
|
# Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
|
|
self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages(
|
|
tiled_mma,
|
|
self.mma_tiler,
|
|
self.a_dtype,
|
|
self.b_dtype,
|
|
self.epi_tile,
|
|
self.c_dtype,
|
|
self.c_layout,
|
|
self.smem_capacity,
|
|
self.occupancy,
|
|
self.use_tma_store,
|
|
)
|
|
|
|
# Compute A/B/C shared memory layout
|
|
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
|
|
tiled_mma,
|
|
self.mma_tiler,
|
|
self.a_dtype,
|
|
self.num_ab_stage,
|
|
)
|
|
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
|
|
tiled_mma,
|
|
self.mma_tiler,
|
|
self.b_dtype,
|
|
self.num_ab_stage,
|
|
)
|
|
self.c_smem_layout_staged = (
|
|
sm100_utils.make_smem_layout_epi(
|
|
self.c_dtype,
|
|
self.c_layout,
|
|
self.epi_tile,
|
|
self.num_c_stage,
|
|
)
|
|
if cutlass.const_expr(self.use_tma_store)
|
|
else None
|
|
)
|
|
|
|
# Compute the number of tensor memory allocation columns
|
|
self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
|
|
tiled_mma, self.mma_tiler, self.num_acc_stage
|
|
)
|
|
|
|
@cute.jit
|
|
def __call__(
|
|
self,
|
|
a: cute.Tensor,
|
|
b: cute.Tensor,
|
|
c: cute.Tensor,
|
|
max_active_clusters: cutlass.Constexpr,
|
|
stream: cuda.CUstream,
|
|
epilogue_op: cutlass.Constexpr = lambda x: x,
|
|
):
|
|
"""Execute the GEMM operation in steps:
|
|
- Setup static attributes before smem/grid/tma computation
|
|
- Setup TMA load/store atoms and tensors
|
|
- Compute grid size with regard to hardware constraints
|
|
- Define shared storage for kernel
|
|
- Launch the kernel synchronously
|
|
|
|
:param a: Input tensor A
|
|
:type a: cute.Tensor
|
|
:param b: Input tensor B
|
|
:type b: cute.Tensor
|
|
:param c: Output tensor C
|
|
:type c: cute.Tensor
|
|
:param max_active_clusters: Maximum number of active clusters
|
|
:type max_active_clusters: cutlass.Constexpr
|
|
:param stream: CUDA stream for asynchronous execution
|
|
:type stream: cuda.CUstream
|
|
:param epilogue_op: Optional elementwise lambda function to apply to the output tensor
|
|
:type epilogue_op: cutlass.Constexpr
|
|
:raises TypeError: If input data types are incompatible with the MMA instruction.
|
|
:raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
|
|
"""
|
|
# Setup static attributes before smem/grid/tma computation
|
|
self.a_dtype: Type[cutlass.Numeric] = a.element_type
|
|
self.b_dtype: Type[cutlass.Numeric] = b.element_type
|
|
self.c_dtype: Type[cutlass.Numeric] = c.element_type
|
|
self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
|
|
self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
|
|
self.c_layout = utils.LayoutEnum.from_tensor(c)
|
|
|
|
# Check if input data types are compatible with MMA instruction
|
|
if cutlass.const_expr(self.a_dtype != self.b_dtype):
|
|
raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
|
|
|
|
        # Set up attributes that depend on the GEMM inputs
|
|
self._setup_attributes()
|
|
|
|
tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
self.a_dtype,
|
|
self.a_major_mode,
|
|
self.b_major_mode,
|
|
self.acc_dtype,
|
|
self.cta_group,
|
|
self.mma_tiler[:2],
|
|
)
|
|
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
|
|
|
# Setup TMA load for A
|
|
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
|
self.cluster_shape_mn, tiled_mma.thr_id
|
|
)
|
|
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
|
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
|
a_op,
|
|
a,
|
|
a_smem_layout,
|
|
self.mma_tiler,
|
|
tiled_mma,
|
|
self.cluster_layout_vmnk.shape,
|
|
internal_type=(
|
|
cutlass.TFloat32 if a.element_type is cutlass.Float32 else None
|
|
),
|
|
)
|
|
|
|
# Setup TMA load for B
|
|
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
|
|
self.cluster_shape_mn, tiled_mma.thr_id
|
|
)
|
|
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
|
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
|
b_op,
|
|
b,
|
|
b_smem_layout,
|
|
self.mma_tiler,
|
|
tiled_mma,
|
|
self.cluster_layout_vmnk.shape,
|
|
internal_type=(
|
|
cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
|
|
),
|
|
)
|
|
|
|
a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
|
|
b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
|
|
self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
|
|
|
|
# Setup TMA store for C
|
|
tma_atom_c = None
|
|
tma_tensor_c = None
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
c_cta_v_layout = cute.composition(
|
|
cute.make_identity_layout(c.shape), self.epi_tile
|
|
)
|
|
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
|
|
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
|
cpasync.CopyBulkTensorTileS2GOp(),
|
|
c,
|
|
epi_smem_layout,
|
|
c_cta_v_layout,
|
|
)
|
|
|
|
# Compute grid size
|
|
self.tile_sched_params, grid = self._compute_grid(
|
|
c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters
|
|
)
|
|
|
|
self.buffer_align_bytes = 1024
|
|
|
|
c_smem_size = (
|
|
cute.cosize(self.c_smem_layout_staged.outer)
|
|
if cutlass.const_expr(self.use_tma_store)
|
|
else 0
|
|
)
|
|
|
|
# Define shared storage for kernel
|
|
@cute.struct
|
|
class SharedStorage:
|
|
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
|
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
|
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
|
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
|
tmem_dealloc_mbar_ptr: cutlass.Int64
|
|
tmem_holding_buf: cutlass.Int32
|
|
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
|
sC: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.c_dtype,
|
|
c_smem_size,
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
sA: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
sB: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
|
|
self.shared_storage = SharedStorage
|
|
|
|
# Launch the kernel synchronously
|
|
self.kernel(
|
|
tiled_mma,
|
|
tma_atom_a,
|
|
tma_tensor_a,
|
|
tma_atom_b,
|
|
tma_tensor_b,
|
|
tma_atom_c,
|
|
tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c,
|
|
self.cluster_layout_vmnk,
|
|
self.a_smem_layout_staged,
|
|
self.b_smem_layout_staged,
|
|
self.c_smem_layout_staged,
|
|
self.epi_tile,
|
|
self.tile_sched_params,
|
|
epilogue_op,
|
|
).launch(
|
|
grid=grid,
|
|
block=[self.threads_per_cta, 1, 1],
|
|
cluster=(*self.cluster_shape_mn, 1),
|
|
smem=self.shared_storage.size_in_bytes(),
|
|
stream=stream,
|
|
)
|
|
return
|
|
|
|
# GPU device kernel
|
|
@cute.kernel
|
|
def kernel(
|
|
self,
|
|
tiled_mma: cute.TiledMma,
|
|
tma_atom_a: cute.CopyAtom,
|
|
mA_mkl: cute.Tensor,
|
|
tma_atom_b: cute.CopyAtom,
|
|
mB_nkl: cute.Tensor,
|
|
tma_atom_c: Optional[cute.CopyAtom],
|
|
mC_mnl: cute.Tensor,
|
|
cluster_layout_vmnk: cute.Layout,
|
|
a_smem_layout_staged: cute.ComposedLayout,
|
|
b_smem_layout_staged: cute.ComposedLayout,
|
|
c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
|
|
epi_tile: cute.Tile,
|
|
tile_sched_params: utils.PersistentTileSchedulerParams,
|
|
epilogue_op: cutlass.Constexpr,
|
|
):
|
|
"""
|
|
GPU device kernel performing the Persistent batched GEMM computation.
|
|
"""
|
|
warp_idx = cute.arch.warp_idx()
|
|
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
|
|
|
#
|
|
# Prefetch tma desc
|
|
#
|
|
if warp_idx == self.tma_warp_id:
|
|
cpasync.prefetch_descriptor(tma_atom_a)
|
|
cpasync.prefetch_descriptor(tma_atom_b)
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
cpasync.prefetch_descriptor(tma_atom_c)
|
|
|
|
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
|
|
|
|
#
|
|
# Setup cta/thread coordinates
|
|
#
|
|
# Coords inside cluster
|
|
bidx, bidy, bidz = cute.arch.block_idx()
|
|
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
|
|
is_leader_cta = mma_tile_coord_v == 0
|
|
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
|
cute.arch.block_idx_in_cluster()
|
|
)
|
|
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
|
|
cta_rank_in_cluster
|
|
)
|
|
# Coord inside cta
|
|
tidx, _, _ = cute.arch.thread_idx()
|
|
|
|
#
|
|
# Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
|
|
#
|
|
smem = utils.SmemAllocator()
|
|
storage = smem.allocate(self.shared_storage)
|
|
|
|
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
|
|
tmem_holding_buf = storage.tmem_holding_buf
|
|
|
|
# Initialize mainloop ab_pipeline (barrier) and states
|
|
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
|
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, num_tma_producer
|
|
)
|
|
ab_pipeline = pipeline.PipelineTmaUmma.create(
|
|
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
|
|
num_stages=self.num_ab_stage,
|
|
producer_group=ab_pipeline_producer_group,
|
|
consumer_group=ab_pipeline_consumer_group,
|
|
tx_count=self.num_tma_load_bytes,
|
|
cta_layout_vmnk=cluster_layout_vmnk,
|
|
)
|
|
|
|
# Initialize acc_pipeline (barrier) and states
|
|
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
num_acc_consumer_threads = len(self.epilog_warp_id) * (
|
|
2 if use_2cta_instrs else 1
|
|
)
|
|
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, num_acc_consumer_threads
|
|
)
|
|
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
|
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
|
|
num_stages=self.num_acc_stage,
|
|
producer_group=acc_pipeline_producer_group,
|
|
consumer_group=acc_pipeline_consumer_group,
|
|
cta_layout_vmnk=cluster_layout_vmnk,
|
|
)
|
|
|
|
# Tensor memory dealloc barrier init
|
|
if use_2cta_instrs:
|
|
if warp_idx == self.tma_warp_id:
|
|
num_tmem_dealloc_threads = 32
|
|
with cute.arch.elect_one():
|
|
cute.arch.mbarrier_init(
|
|
tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads
|
|
)
|
|
cute.arch.mbarrier_init_fence()
|
|
|
|
# Cluster arrive after barrier init
|
|
if cute.size(self.cluster_shape_mn) > 1:
|
|
cute.arch.cluster_arrive_relaxed()
|
|
|
|
#
|
|
# Setup smem tensor A/B/C
|
|
#
|
|
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
|
sC = (
|
|
storage.sC.get_tensor(
|
|
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
|
|
)
|
|
if cutlass.const_expr(self.use_tma_store)
|
|
else None
|
|
)
|
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
sA = storage.sA.get_tensor(
|
|
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
|
|
)
|
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
sB = storage.sB.get_tensor(
|
|
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
|
|
)
|
|
|
|
#
|
|
# Compute multicast mask for A/B buffer full
|
|
#
|
|
a_full_mcast_mask = None
|
|
b_full_mcast_mask = None
|
|
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
|
|
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
|
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
|
|
)
|
|
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
|
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
|
|
)
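            # For example, in a (2, 2) cluster the A tile is multicast along the cluster N
            # dimension (mcast_mode=2) and the B tile along the cluster M dimension
            # (mcast_mode=1), since those CTAs consume the same A (resp. B) tile.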
|
|
|
|
#
|
|
# Local_tile partition global tensors
|
|
#
|
|
# (bM, bK, RestM, RestK, RestL)
|
|
gA_mkl = cute.local_tile(
|
|
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
|
|
)
|
|
# (bN, bK, RestN, RestK, RestL)
|
|
gB_nkl = cute.local_tile(
|
|
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
|
|
)
|
|
# (bM, bN, RestM, RestN, RestL)
|
|
gC_mnl = cute.local_tile(
|
|
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
|
)
|
|
k_block_cnt = cute.size(gA_mkl, mode=[3])
|
|
|
|
#
|
|
# Partition global tensor for TiledMMA_A/B/C
|
|
#
|
|
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
|
|
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
|
|
tCgA = thr_mma.partition_A(gA_mkl)
|
|
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
|
|
tCgB = thr_mma.partition_B(gB_nkl)
|
|
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
|
|
tCgC = thr_mma.partition_C(gC_mnl)
|
|
|
|
#
|
|
# Partition global/shared tensor for TMA load A/B
|
|
#
|
|
# TMA load A partition_S/D
|
|
a_cta_layout = cute.make_layout(
|
|
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
|
|
)
|
|
# ((atom_v, rest_v), STAGE)
|
|
# ((atom_v, rest_v), RestM, RestK, RestL)
|
|
tAsA, tAgA = cpasync.tma_partition(
|
|
tma_atom_a,
|
|
block_in_cluster_coord_vmnk[2],
|
|
a_cta_layout,
|
|
cute.group_modes(sA, 0, 3),
|
|
cute.group_modes(tCgA, 0, 3),
|
|
)
|
|
# TMA load B partition_S/D
|
|
b_cta_layout = cute.make_layout(
|
|
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
|
|
)
|
|
# ((atom_v, rest_v), STAGE)
|
|
# ((atom_v, rest_v), RestM, RestK, RestL)
|
|
tBsB, tBgB = cpasync.tma_partition(
|
|
tma_atom_b,
|
|
block_in_cluster_coord_vmnk[1],
|
|
b_cta_layout,
|
|
cute.group_modes(sB, 0, 3),
|
|
cute.group_modes(tCgB, 0, 3),
|
|
)
|
|
|
|
#
|
|
# Partition shared/tensor memory tensor for TiledMMA_A/B/C
|
|
#
|
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
tCrA = tiled_mma.make_fragment_A(sA)
|
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
tCrB = tiled_mma.make_fragment_B(sB)
|
|
# (MMA, MMA_M, MMA_N)
|
|
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
tCtAcc_fake = tiled_mma.make_fragment_C(
|
|
cute.append(acc_shape, self.num_acc_stage)
|
|
)
|
|
|
|
#
|
|
# Cluster wait before tensor memory alloc
|
|
#
|
|
if cute.size(self.cluster_shape_mn) > 1:
|
|
cute.arch.cluster_wait()
|
|
else:
|
|
cute.arch.barrier(
|
|
barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta
|
|
)
|
|
|
|
#
|
|
# Specialized TMA load warp
|
|
#
|
|
|
|
if warp_idx == self.tma_warp_id:
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
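            # Each persistent CTA starts from its own initial tile and, at the bottom of the
            # loop below, advances to its next tile (strided across the launched grid) until
            # no valid tiles remain.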
|
|
|
|
ab_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.num_ab_stage
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
mma_tile_coord_mnl = (
|
|
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
|
cur_tile_coord[1],
|
|
cur_tile_coord[2],
|
|
)
|
|
|
|
#
|
|
# Slice to per mma tile index
|
|
#
|
|
# ((atom_v, rest_v), RestK)
|
|
tAgA_slice = tAgA[
|
|
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
|
|
]
|
|
# ((atom_v, rest_v), RestK)
|
|
tBgB_slice = tBgB[
|
|
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
|
]
|
|
|
|
                # Peek (try_wait) AB buffer empty for k_block = 0
|
|
ab_producer_state.reset_count()
|
|
peek_ab_empty_status = cutlass.Boolean(1)
|
|
if ab_producer_state.count < k_block_cnt:
|
|
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
|
ab_producer_state
|
|
)
|
|
#
|
|
# Tma load loop
|
|
#
|
|
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
|
|
# Conditionally wait for AB buffer empty
|
|
ab_pipeline.producer_acquire(
|
|
ab_producer_state, peek_ab_empty_status
|
|
)
|
|
|
|
# TMA load A/B
|
|
cute.copy(
|
|
tma_atom_a,
|
|
tAgA_slice[(None, ab_producer_state.count)],
|
|
tAsA[(None, ab_producer_state.index)],
|
|
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
|
mcast_mask=a_full_mcast_mask,
|
|
)
|
|
cute.copy(
|
|
tma_atom_b,
|
|
tBgB_slice[(None, ab_producer_state.count)],
|
|
tBsB[(None, ab_producer_state.index)],
|
|
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
|
mcast_mask=b_full_mcast_mask,
|
|
)
|
|
|
|
                    # Peek (try_wait) AB buffer empty for the next k_block
|
|
ab_producer_state.advance()
|
|
peek_ab_empty_status = cutlass.Boolean(1)
|
|
if ab_producer_state.count < k_block_cnt:
|
|
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
|
ab_producer_state
|
|
)
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
|
|
#
|
|
# Wait A/B buffer empty
|
|
#
|
|
ab_pipeline.producer_tail(ab_producer_state)
|
|
|
|
#
|
|
# Specialized MMA warp
|
|
#
|
|
if warp_idx == self.mma_warp_id:
|
|
#
|
|
            # Barrier sync before retrieving the tensor memory ptr from shared memory
|
|
#
|
|
tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id))
|
|
cute.arch.barrier(
|
|
barrier_id=self.tmem_ptr_sync_bar_id,
|
|
number_of_threads=tmem_ptr_read_threads,
|
|
)
|
|
|
|
#
|
|
            # Retrieve the tensor memory ptr and make the accumulator tensor
|
|
#
|
|
tmem_ptr = cute.arch.retrieve_tmem_ptr(
|
|
self.acc_dtype,
|
|
alignment=16,
|
|
ptr_to_buffer_holding_addr=tmem_holding_buf,
|
|
)
|
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
|
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
ab_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
|
)
|
|
acc_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.num_acc_stage
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
mma_tile_coord_mnl = (
|
|
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
|
cur_tile_coord[1],
|
|
cur_tile_coord[2],
|
|
)
|
|
|
|
# Set tensor memory buffer for current tile
|
|
# (MMA, MMA_M, MMA_N)
|
|
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
|
|
|
|
# Peek (try_wait) AB buffer full for k_block = 0
|
|
ab_consumer_state.reset_count()
|
|
peek_ab_full_status = cutlass.Boolean(1)
|
|
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
|
|
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
|
ab_consumer_state
|
|
)
|
|
|
|
#
|
|
# Wait for accumulator buffer empty
|
|
#
|
|
if is_leader_cta:
|
|
acc_pipeline.producer_acquire(acc_producer_state)
|
|
|
|
#
|
|
# Reset the ACCUMULATE field for each tile
|
|
#
|
|
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
|
|
|
#
|
|
# Mma mainloop
|
|
#
|
|
for k_block in range(k_block_cnt):
|
|
if is_leader_cta:
|
|
# Conditionally wait for AB buffer full
|
|
ab_pipeline.consumer_wait(
|
|
ab_consumer_state, peek_ab_full_status
|
|
)
|
|
|
|
# tCtAcc += tCrA * tCrB
|
|
num_kphases = cute.size(tCrA, mode=[2])
|
|
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
|
kphase_coord = (
|
|
None,
|
|
None,
|
|
kphase_idx,
|
|
ab_consumer_state.index,
|
|
)
|
|
|
|
cute.gemm(
|
|
tiled_mma,
|
|
tCtAcc,
|
|
tCrA[kphase_coord],
|
|
tCrB[kphase_coord],
|
|
tCtAcc,
|
|
)
|
|
# Enable accumulate on tCtAcc after first kphase
|
|
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
|
|
|
# Async arrive AB buffer empty
|
|
ab_pipeline.consumer_release(ab_consumer_state)
|
|
|
|
# Peek (try_wait) AB buffer full for k_block = k_block + 1
|
|
ab_consumer_state.advance()
|
|
peek_ab_full_status = cutlass.Boolean(1)
|
|
if ab_consumer_state.count < k_block_cnt:
|
|
if is_leader_cta:
|
|
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
|
ab_consumer_state
|
|
)
|
|
|
|
#
|
|
# Async arrive accumulator buffer full
|
|
#
|
|
if is_leader_cta:
|
|
acc_pipeline.producer_commit(acc_producer_state)
|
|
acc_producer_state.advance()
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
|
|
#
|
|
# Wait for accumulator buffer empty
|
|
#
|
|
acc_pipeline.producer_tail(acc_producer_state)
|
|
#
|
|
# Specialized epilogue warps
|
|
#
|
|
if warp_idx < self.mma_warp_id:
|
|
#
|
|
# Alloc tensor memory buffer
|
|
#
|
|
if warp_idx == self.epilog_warp_id[0]:
|
|
cute.arch.alloc_tmem(
|
|
self.num_tmem_alloc_cols,
|
|
tmem_holding_buf,
|
|
is_two_cta=use_2cta_instrs,
|
|
)
|
|
|
|
#
|
|
            # Barrier sync before retrieving the tensor memory ptr from shared memory
|
|
#
|
|
tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id))
|
|
cute.arch.barrier(
|
|
barrier_id=self.tmem_ptr_sync_bar_id,
|
|
number_of_threads=tmem_ptr_read_threads,
|
|
)
|
|
|
|
#
|
|
            # Retrieve the tensor memory ptr and make the accumulator tensor
|
|
#
|
|
tmem_ptr = cute.arch.retrieve_tmem_ptr(
|
|
self.acc_dtype,
|
|
alignment=16,
|
|
ptr_to_buffer_holding_addr=tmem_holding_buf,
|
|
)
|
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
|
|
|
#
|
|
# Partition for epilogue
|
|
#
|
|
epi_tidx = tidx
|
|
(
|
|
tiled_copy_t2r,
|
|
tTR_tAcc_base,
|
|
tTR_rAcc,
|
|
) = self.epilog_tmem_copy_and_partition(
|
|
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
|
|
)
|
|
|
|
tTR_rC = None
|
|
tiled_copy_r2s = None
|
|
simt_atom = None
|
|
tRS_rC = None
|
|
tRS_sC = None
|
|
bSG_sC = None
|
|
bSG_gC_partitioned = None
|
|
tTR_gC_partitioned = None
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype)
|
|
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
|
|
tiled_copy_t2r, tTR_rC, epi_tidx, sC
|
|
)
|
|
(
|
|
tma_atom_c,
|
|
bSG_sC,
|
|
bSG_gC_partitioned,
|
|
) = self.epilog_gmem_copy_and_partition(
|
|
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
|
|
)
|
|
else:
|
|
(
|
|
simt_atom,
|
|
tTR_rC,
|
|
tTR_gC_partitioned,
|
|
) = self.epilog_gmem_copy_and_partition(
|
|
epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC
|
|
)
|
|
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
acc_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
|
)
|
|
|
|
c_pipeline = None
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
# Threads/warps participating in tma store pipeline
|
|
c_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
32 * len(self.epilog_warp_id),
|
|
32 * len(self.epilog_warp_id),
|
|
)
|
|
c_pipeline = pipeline.PipelineTmaStore.create(
|
|
num_stages=self.num_c_stage,
|
|
producer_group=c_producer_group,
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
mma_tile_coord_mnl = (
|
|
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
|
cur_tile_coord[1],
|
|
cur_tile_coord[2],
|
|
)
|
|
|
|
#
|
|
# Slice to per mma tile index
|
|
#
|
|
bSG_gC = None
|
|
tTR_gC = None
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
|
bSG_gC = bSG_gC_partitioned[
|
|
(
|
|
None,
|
|
None,
|
|
None,
|
|
*mma_tile_coord_mnl,
|
|
)
|
|
]
|
|
else:
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
|
tTR_gC = tTR_gC_partitioned[
|
|
(
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
*mma_tile_coord_mnl,
|
|
)
|
|
]
|
|
|
|
# Set tensor memory buffer for current tile
|
|
                # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
|
tTR_tAcc = tTR_tAcc_base[
|
|
(None, None, None, None, None, acc_consumer_state.index)
|
|
]
|
|
|
|
#
|
|
# Wait for accumulator buffer full
|
|
#
|
|
acc_pipeline.consumer_wait(acc_consumer_state)
|
|
|
|
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
|
|
else:
|
|
tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC))
|
|
|
|
#
|
|
# Store accumulator to global memory in subtiles
|
|
#
|
|
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
|
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
|
|
for subtile_idx in cutlass.range(subtile_cnt):
|
|
#
|
|
# Load accumulator from tensor memory buffer to register
|
|
#
|
|
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
|
|
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
|
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
#
|
|
# Convert to C type
|
|
#
|
|
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
|
|
acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
|
|
tRS_rC.store(acc_vec)
|
|
|
|
#
|
|
# Store C to shared memory
|
|
#
|
|
c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage
|
|
cute.copy(
|
|
tiled_copy_r2s,
|
|
tRS_rC,
|
|
tRS_sC[(None, None, None, c_buffer)],
|
|
)
|
|
# Fence and barrier to make sure shared memory store is visible to TMA store
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
epilog_threads = 32 * len(self.epilog_warp_id)
|
|
cute.arch.barrier(
|
|
barrier_id=self.epilog_sync_bar_id,
|
|
number_of_threads=epilog_threads,
|
|
)
|
|
|
|
#
|
|
# TMA store C to global memory
|
|
#
|
|
if warp_idx == self.epilog_warp_id[0]:
|
|
cute.copy(
|
|
tma_atom_c,
|
|
bSG_sC[(None, c_buffer)],
|
|
bSG_gC[(None, subtile_idx)],
|
|
)
|
|
                            # Commit the TMA store and wait for a free C stage before the smem buffer is reused
|
|
c_pipeline.producer_commit()
|
|
c_pipeline.producer_acquire()
|
|
cute.arch.barrier(
|
|
barrier_id=self.epilog_sync_bar_id,
|
|
number_of_threads=epilog_threads,
|
|
)
|
|
else:
|
|
#
|
|
# Convert to C type
|
|
#
|
|
acc_vec = tTR_rAcc.load()
|
|
acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
|
|
tTR_rC.store(acc_vec)
|
|
|
|
#
|
|
# Store C to global memory
|
|
#
|
|
cute.copy(
|
|
simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]
|
|
)
|
|
|
|
#
|
|
# Async arrive accumulator buffer empty
|
|
#
|
|
with cute.arch.elect_one():
|
|
acc_pipeline.consumer_release(acc_consumer_state)
|
|
acc_consumer_state.advance()
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
|
|
#
|
|
# Dealloc the tensor memory buffer
|
|
#
|
|
if warp_idx == self.epilog_warp_id[0]:
|
|
cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs)
|
|
epilog_threads = 32 * len(self.epilog_warp_id)
|
|
cute.arch.barrier(
|
|
barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads
|
|
)
|
|
if warp_idx == self.epilog_warp_id[0]:
|
|
if use_2cta_instrs:
|
|
cute.arch.mbarrier_arrive(
|
|
tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1
|
|
)
|
|
cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
|
|
cute.arch.dealloc_tmem(
|
|
tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs
|
|
)
|
|
#
|
|
# Wait for C store complete
|
|
#
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
c_pipeline.producer_tail()
|
|
|
|
def epilog_tmem_copy_and_partition(
|
|
self,
|
|
tidx: cutlass.Int32,
|
|
tAcc: cute.Tensor,
|
|
gC_mnl: cute.Tensor,
|
|
epi_tile: cute.Tile,
|
|
use_2cta_instrs: Union[cutlass.Boolean, bool],
|
|
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
"""
|
|
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
|
|
|
|
:param tidx: The thread index in epilogue warp groups
|
|
:type tidx: cutlass.Int32
|
|
:param tAcc: The accumulator tensor to be copied and partitioned
|
|
:type tAcc: cute.Tensor
|
|
:param gC_mnl: The global tensor C
|
|
:type gC_mnl: cute.Tensor
|
|
:param epi_tile: The epilogue tiler
|
|
:type epi_tile: cute.Tile
|
|
:param use_2cta_instrs: Whether use_2cta_instrs is enabled
|
|
:type use_2cta_instrs: bool
|
|
|
|
:return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
|
|
- tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
|
|
- tTR_tAcc: The partitioned accumulator tensor
|
|
- tTR_rAcc: The accumulated tensor in register used to hold t2r results
|
|
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
|
"""
|
|
# Make tiledCopy for tensor memory load
|
|
copy_atom_t2r = sm100_utils.get_tmem_load_op(
|
|
self.cta_tile_shape_mnk,
|
|
self.c_layout,
|
|
self.c_dtype,
|
|
self.acc_dtype,
|
|
epi_tile,
|
|
use_2cta_instrs,
|
|
)
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
|
|
tAcc_epi = cute.flat_divide(
|
|
tAcc[((None, None), 0, 0, None)],
|
|
epi_tile,
|
|
)
|
|
# (EPI_TILE_M, EPI_TILE_N)
|
|
tiled_copy_t2r = tcgen05.make_tmem_copy(
|
|
copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]
|
|
)
|
|
|
|
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
|
        # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
|
|
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
|
|
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
|
gC_mnl_epi = cute.flat_divide(
|
|
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
|
)
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
|
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
|
|
# (T2R, T2R_M, T2R_N)
|
|
tTR_rAcc = cute.make_fragment(
|
|
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype
|
|
)
|
|
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
|
|
|
|
def epilog_smem_copy_and_partition(
|
|
self,
|
|
tiled_copy_t2r: cute.TiledCopy,
|
|
tTR_rC: cute.Tensor,
|
|
tidx: cutlass.Int32,
|
|
sC: cute.Tensor,
|
|
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
"""
|
|
Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
|
|
|
|
:param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
|
|
:type tiled_copy_t2r: cute.TiledCopy
|
|
:param tTR_rC: The partitioned accumulator tensor
|
|
:type tTR_rC: cute.Tensor
|
|
:param tidx: The thread index in epilogue warp groups
|
|
:type tidx: cutlass.Int32
|
|
:param sC: The shared memory tensor to be copied and partitioned
|
|
:type sC: cute.Tensor
|
|
|
|
|
|
:return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where:
|
|
- tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s)
|
|
- tRS_rC: The partitioned tensor C (register source)
|
|
- tRS_sC: The partitioned tensor C (smem destination)
|
|
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
|
"""
|
|
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
|
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
|
)
|
|
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
|
# (R2S, R2S_M, R2S_N, PIPE_D)
|
|
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
|
tRS_sC = thr_copy_r2s.partition_D(sC)
|
|
# (R2S, R2S_M, R2S_N)
|
|
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
|
|
return tiled_copy_r2s, tRS_rC, tRS_sC
|
|
|
|
def epilog_gmem_copy_and_partition(
|
|
self,
|
|
tidx: cutlass.Int32,
|
|
atom: Union[cute.CopyAtom, cute.TiledCopy],
|
|
gC_mnl: cute.Tensor,
|
|
epi_tile: cute.Tile,
|
|
sC: cute.Tensor,
|
|
) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]:
|
|
"""Make tiledCopy for global memory store, then use it to:
|
|
- partition register array (source) and global memory (destination) for none TMA store version;
|
|
- partition shared memory (source) and global memory (destination) for TMA store version.
|
|
|
|
:param tidx: The thread index in epilogue warp groups
|
|
:type tidx: cutlass.Int32
|
|
        :param atom: The copy atom: tma_atom_c for the TMA store version, or tiled_copy_t2r for the non-TMA store version
|
|
:type atom: cute.CopyAtom or cute.TiledCopy
|
|
:param gC_mnl: The global tensor C
|
|
:type gC_mnl: cute.Tensor
|
|
:param epi_tile: The epilogue tiler
|
|
:type epi_tile: cute.Tile
|
|
:param sC: The shared memory tensor to be copied and partitioned
|
|
:type sC: cute.Tensor
|
|
|
|
:return: A tuple containing either:
|
|
- For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where:
|
|
- tma_atom_c: The TMA copy atom
|
|
- bSG_sC: The partitioned shared memory tensor C
|
|
- bSG_gC: The partitioned global tensor C
|
|
- For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where:
|
|
- simt_atom: The SIMT copy atom
|
|
- tTR_rC: The register tensor C
|
|
- tTR_gC: The partitioned global tensor C
|
|
:rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
|
|
"""
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
|
gC_epi = cute.flat_divide(
|
|
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
|
)
|
|
if cutlass.const_expr(self.use_tma_store):
|
|
tma_atom_c = atom
|
|
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
|
|
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
|
|
bSG_sC, bSG_gC = cpasync.tma_partition(
|
|
tma_atom_c,
|
|
0,
|
|
cute.make_layout(1),
|
|
sC_for_tma_partition,
|
|
gC_for_tma_partition,
|
|
)
|
|
return tma_atom_c, bSG_sC, bSG_gC
|
|
else:
|
|
tiled_copy_t2r = atom
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
|
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
|
tTR_gC = thr_copy_t2r.partition_D(gC_epi)
|
|
# (T2R, T2R_M, T2R_N)
|
|
tTR_rC = cute.make_fragment(
|
|
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype
|
|
)
|
|
simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype)
|
|
return simt_atom, tTR_rC, tTR_gC
|
|
|
|
@staticmethod
|
|
def _compute_stages(
|
|
tiled_mma: cute.TiledMma,
|
|
mma_tiler_mnk: Tuple[int, int, int],
|
|
a_dtype: Type[cutlass.Numeric],
|
|
b_dtype: Type[cutlass.Numeric],
|
|
epi_tile: cute.Tile,
|
|
c_dtype: Type[cutlass.Numeric],
|
|
c_layout: utils.LayoutEnum,
|
|
smem_capacity: int,
|
|
occupancy: int,
|
|
use_tma_store: bool,
|
|
) -> Tuple[int, int, int]:
|
|
"""Computes the number of stages for A/B/C operands based on heuristics.
|
|
|
|
:param tiled_mma: The tiled MMA object defining the core computation.
|
|
:type tiled_mma: cute.TiledMma
|
|
:param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
|
|
:type mma_tiler_mnk: tuple[int, int, int]
|
|
:param a_dtype: Data type of operand A.
|
|
:type a_dtype: type[cutlass.Numeric]
|
|
:param b_dtype: Data type of operand B.
|
|
:type b_dtype: type[cutlass.Numeric]
|
|
:param epi_tile: The epilogue tile shape.
|
|
:type epi_tile: cute.Tile
|
|
:param c_dtype: Data type of operand C (output).
|
|
:type c_dtype: type[cutlass.Numeric]
|
|
:param c_layout: Layout enum of operand C.
|
|
:type c_layout: utils.LayoutEnum
|
|
:param smem_capacity: Total available shared memory capacity in bytes.
|
|
:type smem_capacity: int
|
|
:param occupancy: Target number of CTAs per SM (occupancy).
|
|
:type occupancy: int
|
|
:param use_tma_store: Whether TMA store is enabled.
|
|
:type use_tma_store: bool
|
|
|
|
:return: A tuple containing the computed number of stages for:
|
|
(ACC stages, A/B operand stages, C stages)
|
|
:rtype: tuple[int, int, int]
|
|
"""
|
|
# Default ACC stages
|
|
num_acc_stage = 2
|
|
|
|
# Default C stages
|
|
num_c_stage = 2 if use_tma_store else 0
|
|
|
|
# Calculate smem layout and size for one stage of A, B, and C
|
|
a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
|
|
tiled_mma,
|
|
mma_tiler_mnk,
|
|
a_dtype,
|
|
            1,  # a temporary single stage, used only to compute the per-stage size
|
|
)
|
|
b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
|
|
tiled_mma,
|
|
mma_tiler_mnk,
|
|
b_dtype,
|
|
            1,  # a temporary single stage, used only to compute the per-stage size
|
|
)
|
|
c_smem_layout_staged_one = (
|
|
sm100_utils.make_smem_layout_epi(
|
|
c_dtype,
|
|
c_layout,
|
|
epi_tile,
|
|
1,
|
|
)
|
|
if use_tma_store
|
|
else None
|
|
)
|
|
ab_bytes_per_stage = cute.size_in_bytes(
|
|
a_dtype, a_smem_layout_stage_one
|
|
) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
|
|
mbar_helpers_bytes = 1024
|
|
c_bytes_per_stage = (
|
|
cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
|
|
if use_tma_store
|
|
else 0
|
|
)
|
|
c_bytes = c_bytes_per_stage * num_c_stage
|
|
|
|
# Calculate A/B stages:
|
|
# Start with total smem per CTA (capacity / occupancy)
|
|
# Subtract reserved bytes and initial C stages bytes
|
|
# Divide remaining by bytes needed per A/B stage
|
|
num_ab_stage = (
|
|
smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
|
|
) // ab_bytes_per_stage
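        # Illustrative arithmetic (assumed numbers, not taken from this file): with ~224 KiB
        # of SMEM per CTA, occupancy 1, ~1 KiB reserved for mbarriers, and 2 C stages of
        # 8 KiB each, a Float16 (128, 256, 64) A/B stage needs (128 + 256) * 64 * 2 B = 48 KiB,
        # so num_ab_stage = (224 - 1 - 16) // 48 = 4.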
|
|
|
|
# Refine epilogue stages:
|
|
# Calculate remaining smem after allocating for A/B stages and reserved bytes
|
|
# Add remaining unused smem to epilogue
|
|
if use_tma_store:
|
|
num_c_stage += (
|
|
smem_capacity
|
|
- occupancy * ab_bytes_per_stage * num_ab_stage
|
|
- occupancy * (mbar_helpers_bytes + c_bytes)
|
|
) // (occupancy * c_bytes_per_stage)
|
|
return num_acc_stage, num_ab_stage, num_c_stage
|
|
|
|
@staticmethod
|
|
def _compute_grid(
|
|
c: cute.Tensor,
|
|
cta_tile_shape_mnk: Tuple[int, int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
max_active_clusters: cutlass.Constexpr,
|
|
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
|
|
"""Use persistent tile scheduler to compute the grid size for the output tensor C.
|
|
|
|
:param c: The output tensor C
|
|
:type c: cute.Tensor
|
|
:param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
|
|
:type cta_tile_shape_mnk: tuple[int, int, int]
|
|
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
|
|
:type cluster_shape_mn: tuple[int, int]
|
|
:param max_active_clusters: Maximum number of active clusters.
|
|
:type max_active_clusters: cutlass.Constexpr
|
|
|
|
:return: A tuple containing:
|
|
- tile_sched_params: Parameters for the persistent tile scheduler.
|
|
- grid: Grid shape for kernel launch.
|
|
:rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
|
|
"""
|
|
c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
|
|
gc = cute.zipped_divide(c, tiler=c_shape)
|
|
num_ctas_mnl = gc[(0, (None, None, None))].shape
|
|
cluster_shape_mnl = (*cluster_shape_mn, 1)
|
|
|
|
tile_sched_params = utils.PersistentTileSchedulerParams(
|
|
num_ctas_mnl, cluster_shape_mnl
|
|
)
|
|
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
|
|
tile_sched_params, max_active_clusters
|
|
)
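        # For example (illustrative numbers): an 8192 x 8192 C with a (128, 256) CTA tile
        # yields num_ctas_mnl = (64, 32, 1), i.e. 2048 tiles, while the launched grid only
        # contains enough persistent CTAs to fill max_active_clusters clusters.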
|
|
|
|
return tile_sched_params, grid
|
|
|
|
@staticmethod
|
|
def _compute_num_tmem_alloc_cols(
|
|
tiled_mma: cute.TiledMma,
|
|
mma_tiler: Tuple[int, int, int],
|
|
num_acc_stage: int,
|
|
) -> int:
|
|
"""
|
|
Compute the number of tensor memory allocation columns.
|
|
|
|
:param tiled_mma: The tiled MMA object defining the core computation.
|
|
:type tiled_mma: cute.TiledMma
|
|
:param mma_tiler: The shape (M, N, K) of the MMA tile.
|
|
:type mma_tiler: tuple[int, int, int]
|
|
        :param num_acc_stage: The number of accumulator stages.
|
|
:type num_acc_stage: int
|
|
|
|
:return: The number of tensor memory allocation columns.
|
|
:rtype: int
|
|
"""
|
|
acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2])
|
|
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage))
|
|
num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake)
|
|
|
|
return num_tmem_alloc_cols
|
|
|
|
@staticmethod
|
|
def is_valid_dtypes(
|
|
ab_dtype: Type[cutlass.Numeric],
|
|
acc_dtype: Type[cutlass.Numeric],
|
|
c_dtype: Type[cutlass.Numeric],
|
|
) -> bool:
|
|
"""
|
|
Check if the dtypes are valid
|
|
|
|
:param ab_dtype: The data type of the A and B operands
|
|
:type ab_dtype: Type[cutlass.Numeric]
|
|
:param acc_dtype: The data type of the accumulator
|
|
:type acc_dtype: Type[cutlass.Numeric]
|
|
:param c_dtype: The data type of the output tensor
|
|
:type c_dtype: Type[cutlass.Numeric]
|
|
|
|
:return: True if the dtypes are valid, False otherwise
|
|
:rtype: bool
|
|
"""
|
|
is_valid = True
|
|
if ab_dtype not in {
|
|
cutlass.Float16,
|
|
cutlass.BFloat16,
|
|
cutlass.TFloat32,
|
|
cutlass.Uint8,
|
|
cutlass.Int8,
|
|
cutlass.Float8E4M3FN,
|
|
cutlass.Float8E5M2,
|
|
}:
|
|
is_valid = False
|
|
if (
|
|
acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}
|
|
or acc_dtype == cutlass.Float16
|
|
and ab_dtype
|
|
not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
|
|
or acc_dtype == cutlass.Int32
|
|
and ab_dtype not in {cutlass.Uint8, cutlass.Int8}
|
|
):
|
|
is_valid = False
|
|
if (
|
|
acc_dtype == cutlass.Float32
|
|
and c_dtype
|
|
not in {
|
|
cutlass.Float32,
|
|
cutlass.Float16,
|
|
cutlass.BFloat16,
|
|
cutlass.Float8E4M3FN,
|
|
cutlass.Float8E5M2,
|
|
cutlass.Int32,
|
|
cutlass.Int8,
|
|
cutlass.Uint8,
|
|
}
|
|
or acc_dtype == cutlass.Float16
|
|
and c_dtype
|
|
not in {
|
|
cutlass.BFloat16,
|
|
cutlass.Float16,
|
|
}
|
|
or acc_dtype == cutlass.Int32
|
|
and c_dtype
|
|
not in {
|
|
cutlass.BFloat16,
|
|
cutlass.Float16,
|
|
cutlass.Float32,
|
|
cutlass.Int32,
|
|
cutlass.Int8,
|
|
cutlass.Uint8,
|
|
}
|
|
):
|
|
is_valid = False
|
|
return is_valid
|
|
|
|
@staticmethod
|
|
def is_valid_mma_tiler_and_cluster_shape(
|
|
use_2cta_instrs: bool,
|
|
mma_tiler_mn: Tuple[int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
) -> bool:
|
|
"""
|
|
Check if the mma tiler and cluster shape are valid
|
|
|
|
:param use_2cta_instrs: Whether to use 2 CTA groups
|
|
:type use_2cta_instrs: bool
|
|
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
|
|
:type cluster_shape_mn: Tuple[int, int]
|
|
|
|
:return: True if the mma tiler and cluster shape are valid, False otherwise
|
|
:rtype: bool
|
|
"""
|
|
is_valid = True
|
|
# Skip invalid mma tile shape
|
|
if not (
|
|
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
|
|
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
|
|
):
|
|
is_valid = False
|
|
if mma_tiler_mn[1] not in range(32, 257, 32):
|
|
is_valid = False
|
|
# Skip illegal cluster shape
|
|
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
|
|
is_valid = False
|
|
# Skip invalid cluster shape
|
|
is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
|
|
if (
|
|
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
|
|
or cluster_shape_mn[0] <= 0
|
|
or cluster_shape_mn[1] <= 0
|
|
or not is_power_of_2(cluster_shape_mn[0])
|
|
or not is_power_of_2(cluster_shape_mn[1])
|
|
):
|
|
is_valid = False
|
|
return is_valid
|
|
|
|
@staticmethod
|
|
def is_valid_tensor_alignment(
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
l: int,
|
|
ab_dtype: Type[cutlass.Numeric],
|
|
c_dtype: Type[cutlass.Numeric],
|
|
a_major: str,
|
|
b_major: str,
|
|
c_major: str,
|
|
) -> bool:
|
|
"""
|
|
Check if the tensor alignment is valid
|
|
|
|
:param m: The number of rows in the A tensor
|
|
:type m: int
|
|
:param n: The number of columns in the B tensor
|
|
:type n: int
|
|
:param k: The number of columns in the A tensor
|
|
:type k: int
|
|
        :param l: The batch dimension (L) of the A, B, and C tensors
|
|
:type l: int
|
|
:param ab_dtype: The data type of the A and B operands
|
|
:type ab_dtype: Type[cutlass.Numeric]
|
|
:param c_dtype: The data type of the output tensor
|
|
:type c_dtype: Type[cutlass.Numeric]
|
|
:param a_major: The major axis of the A tensor
|
|
:type a_major: str
|
|
:param b_major: The major axis of the B tensor
|
|
:type b_major: str
|
|
:param c_major: The major axis of the C tensor
|
|
:type c_major: str
|
|
|
|
:return: True if the problem shape is valid, False otherwise
|
|
:rtype: bool
|
|
"""
|
|
is_valid = True
|
|
|
|
        def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):
|
|
major_mode_idx = 0 if is_mode0_major else 1
|
|
num_major_elements = tensor_shape[major_mode_idx]
|
|
num_contiguous_elements = 16 * 8 // dtype.width
|
|
return num_major_elements % num_contiguous_elements == 0
|
|
|
|
        if (
            not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, l))
            or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, l))
            or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l))
        ):
|
|
is_valid = False
|
|
return is_valid
|
|
|
|
    @staticmethod
    def is_valid_epilog_store_option(
        use_2cta_instrs: bool,
        use_tma_store: bool,
        m: int,
        n: int,
        mma_tiler_mn: Tuple[int, int],
    ) -> bool:
        """
        Check if the epilogue store option is valid

        :param use_2cta_instrs: Whether to use 2 CTA groups
        :type use_2cta_instrs: bool
        :param use_tma_store: Whether to use TMA store
        :type use_tma_store: bool
        :param m: The number of rows in the A tensor
        :type m: int
        :param n: The number of rows in the B tensor (columns of the C tensor)
        :type n: int
        :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
        :type mma_tiler_mn: Tuple[int, int]

        :return: True if the epilogue store option is valid, False otherwise
        :rtype: bool
        """

        is_valid = True
        # The non-TMA store path has no predication, so it cannot handle out-of-bounds (OOB) tiles
        cta_tile_shape_mn = (
            mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
            mma_tiler_mn[1],
        )
        if not use_tma_store:
            if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0):
                is_valid = False
        return is_valid

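    # Example of the constraint in is_valid_epilog_store_option above (illustrative):
    # with use_2cta_instrs=True and mma_tiler_mn=(256, 128), each CTA owns a
    # (128, 128) output tile, so the non-TMA store path requires m % 128 == 0 and
    # n % 128 == 0, while the TMA store path tolerates partial tiles.
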
    @staticmethod
    def can_implement(
        ab_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
        c_dtype: Type[cutlass.Numeric],
        use_2cta_instrs: bool,
        mma_tiler_mn: Tuple[int, int],
        cluster_shape_mn: Tuple[int, int],
        use_tma_store: bool,
        m: int,
        n: int,
        k: int,
        l: int,
        a_major: str,
        b_major: str,
        c_major: str,
    ) -> bool:
        """
        Check if the gemm can be implemented

        :param ab_dtype: The data type of the A and B operands
        :type ab_dtype: Type[cutlass.Numeric]
        :param acc_dtype: The data type of the accumulator
        :type acc_dtype: Type[cutlass.Numeric]
        :param c_dtype: The data type of the output tensor
        :type c_dtype: Type[cutlass.Numeric]
        :param use_2cta_instrs: Whether to use 2 CTA groups
        :type use_2cta_instrs: bool
        :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
        :type mma_tiler_mn: Tuple[int, int]
        :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
        :type cluster_shape_mn: Tuple[int, int]
        :param use_tma_store: Whether to use TMA store
        :type use_tma_store: bool
        :param m: The number of rows in the A tensor
        :type m: int
        :param n: The number of rows in the B tensor (columns of the C tensor)
        :type n: int
        :param k: The number of columns in the A tensor
        :type k: int
        :param l: The batch dimension of the A/B/C tensors
        :type l: int
        :param a_major: The major axis of the A tensor
        :type a_major: str
        :param b_major: The major axis of the B tensor
        :type b_major: str
        :param c_major: The major axis of the C tensor
        :type c_major: str

        :return: True if the gemm can be implemented, False otherwise
        :rtype: bool
        """
        can_implement = True
        # Skip unsupported types
        if not PersistentDenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype):
            can_implement = False
        # Skip invalid mma tile shape and cluster shape
        if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
        ):
            can_implement = False
        # Skip illegal problem shape for load/store alignment
        if not PersistentDenseGemmKernel.is_valid_tensor_alignment(
            m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
        ):
            can_implement = False
        # Skip invalid epilogue store option
        if not PersistentDenseGemmKernel.is_valid_epilog_store_option(
            use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn
        ):
            can_implement = False
        return can_implement


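def _example_preflight_check() -> bool:
    """A minimal usage sketch (hypothetical values, never called in this example):
    validate a candidate configuration with can_implement() before constructing
    and compiling the kernel."""
    return PersistentDenseGemmKernel.can_implement(
        ab_dtype=cutlass.Float16,
        acc_dtype=cutlass.Float32,
        c_dtype=cutlass.Float16,
        use_2cta_instrs=True,
        mma_tiler_mn=(256, 256),
        cluster_shape_mn=(2, 1),
        use_tma_store=True,
        m=8192,
        n=8192,
        k=8192,
        l=1,
        a_major="k",
        b_major="k",
        c_major="n",
    )

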
def run(
    mnkl: Tuple[int, int, int, int],
    ab_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    c_major: str,
    mma_tiler_mn: Tuple[int, int] = (256, 256),
    cluster_shape_mn: Tuple[int, int] = (2, 1),
    use_2cta_instrs: bool = True,
    use_tma_store: bool = True,
    tolerance: float = 1e-01,
    warmup_iterations: int = 0,
    iterations: int = 1,
    skip_ref_check: bool = False,
    use_cold_l2: bool = False,
    **kwargs,
):
    """Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.

    This function prepares input tensors, configures and launches the persistent GEMM kernel,
    optionally performs reference validation, and benchmarks the execution performance.

    :param mnkl: Problem size (M, N, K, L)
    :type mnkl: Tuple[int, int, int, int]
    :param ab_dtype: Data type for input tensors A and B
    :type ab_dtype: Type[cutlass.Numeric]
    :param c_dtype: Data type for output tensor C
    :type c_dtype: Type[cutlass.Numeric]
    :param acc_dtype: Data type for accumulation during matrix multiplication
    :type acc_dtype: Type[cutlass.Numeric]
    :param a_major/b_major/c_major: Memory layout of tensor A/B/C
    :type a_major/b_major/c_major: str
    :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner
        uses the default value of (256, 256); otherwise it uses the value given there.
    :type mma_tiler_mn: Tuple[int, int], optional
    :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner
        uses the default value of (2, 1); otherwise it uses the value given there.
    :type cluster_shape_mn: Tuple[int, int], optional
    :param use_2cta_instrs: Whether to use 2CTA MMA instructions. If not specified in the decorator
        parameters, the autotuner uses the default value of True; otherwise it uses the value given there.
    :type use_2cta_instrs: bool, optional
    :param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the
        autotuner uses the default value of True; otherwise it uses the value given there.
    :type use_tma_store: bool, optional
    :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
    :type tolerance: float, optional
    :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
    :type warmup_iterations: int, optional
    :param iterations: Number of benchmark iterations to run, defaults to 1
    :type iterations: int, optional
    :param skip_ref_check: Whether to skip reference result validation, defaults to False
    :type skip_ref_check: bool, optional
    :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
    :type use_cold_l2: bool, optional
    :raises RuntimeError: If CUDA GPU is not available
    :raises TypeError: If the configuration is invalid or unsupported by the kernel
    :return: Execution time of the GEMM kernel in microseconds
    :rtype: float
    """
print(f"Running Blackwell Persistent Dense GEMM test with:")
|
|
print(f"mnkl: {mnkl}")
|
|
print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}")
|
|
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
|
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
|
|
print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
|
|
print(f"Use TMA Store: {'True' if use_tma_store else 'False'}")
|
|
print(f"Tolerance: {tolerance}")
|
|
print(f"Warmup iterations: {warmup_iterations}")
|
|
print(f"Iterations: {iterations}")
|
|
print(f"Skip reference checking: {skip_ref_check}")
|
|
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
|
|
|
# Unpack parameters
|
|
m, n, k, l = mnkl
|
|
|
|
# Skip unsupported testcase
|
|
if not PersistentDenseGemmKernel.can_implement(
|
|
ab_dtype,
|
|
acc_dtype,
|
|
c_dtype,
|
|
use_2cta_instrs,
|
|
mma_tiler_mn,
|
|
cluster_shape_mn,
|
|
use_tma_store,
|
|
m,
|
|
n,
|
|
k,
|
|
l,
|
|
a_major,
|
|
b_major,
|
|
c_major,
|
|
):
|
|
raise TypeError(
|
|
f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
|
|
)
|
|
|
|
if not torch.cuda.is_available():
|
|
raise RuntimeError("GPU is required to run this example!")
|
|
|
|
torch.manual_seed(1111)
|
|
|
|
    # Create and permute tensor A/B/C
    def create_and_permute_tensor(
        l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True
    ):
        # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
        # else: (l, mode0, mode1) -> (mode0, mode1, l)
        shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
        permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
        is_unsigned = dtype in {cutlass.Uint8}
        # Temporarily use uint8 as torch does not support fp8 types
        torch_dtype = (
            cutlass_torch.dtype(dtype)
            if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
            else torch.uint8
        )

        # Create dtype torch tensor (cpu)
        torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor(
            shape,
            torch_dtype,
            permute_order=permute_order,
            init_type=cutlass_torch.TensorInitType.RANDOM,
            init_config=cutlass_torch.RandomInitConfig(
                min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2
            ),
        )
        # Create dtype torch tensor (gpu)
        torch_tensor = torch_tensor_cpu.cuda()

        # Create f32 torch tensor (cpu)
        f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32)

        # Create dtype cute tensor (gpu)
        cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
        cute_tensor.element_type = dtype
        if is_dynamic_layout:
            cute_tensor = cute_tensor.mark_layout_dynamic(
                leading_dim=(0 if is_mode0_major else 1)
            )
        cute_tensor = cutlass_torch.convert_cute_tensor(
            f32_torch_tensor,
            cute_tensor,
            dtype,
            is_dynamic_layout=is_dynamic_layout,
        )

        return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu

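    # Layout note for create_and_permute_tensor above (illustrative): an m-major A of
    # logical shape (m, k, l) is allocated as a torch tensor of shape (l, k, m) and
    # permuted with (2, 1, 0), so it is viewed as (m, k, l) with the m mode contiguous
    # in memory; the k-major case allocates (l, m, k) and permutes with (1, 2, 0).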
    a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
        l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
    )
    b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
        l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
    )
    c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
        l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
    )

    # Configure gemm kernel
    gemm = PersistentDenseGemmKernel(
        acc_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
        use_tma_store,
    )

    # Compute max active clusters on current device
    hardware_info = cutlass.utils.HardwareInfo()
    max_active_clusters = hardware_info.get_max_active_clusters(
        cluster_shape_mn[0] * cluster_shape_mn[1]
    )

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)
    # Compile gemm kernel
    compiled_gemm = cute.compile(
        gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
    )

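    # Note: cute.compile specializes the kernel for these tensor layouts/dtypes and
    # returns a callable; the same compiled_gemm is re-launched below for the optional
    # reference check and for every benchmark iteration without recompiling.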
    if not skip_ref_check:
        compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
        if ab_dtype in {
            cutlass.Int8,
            cutlass.Uint8,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
        }:
            ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu())
        else:
            ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu()
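        # The einsum above computes ref[m, n, l] = sum_k A[m, k, l] * B[n, k, l],
        # i.e. C = A @ B^T for each of the L batches, in float32.
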
        # Copy gpu result back
        gpu_c = c_torch.cpu()

        # Convert ref to c_dtype
        if c_dtype == cutlass.Float32:
            ref_c = ref
        elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
            # m major: (l, n, m) -> (m, n, l)
            # n major: (l, m, n) -> (m, n, l)
            permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
            shape = (l, m, n) if c_major == "n" else (l, n, m)
            f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
                shape,
                torch.uint8,
                permute_order=permute_order,
                init_type=cutlass_torch.TensorInitType.SKIP,
            ).cuda()
            # Create dtype cute tensor (gpu)
            ref_c_tensor = from_dlpack(
                f8_torch_tensor, assumed_align=16
            ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
            ref_c_tensor.element_type = c_dtype
            ref_c_tensor = cutlass_torch.convert_cute_tensor(
                ref,
                ref_c_tensor,
                c_dtype,
                is_dynamic_layout=True,
            )

            ref_c = f8_torch_tensor.cpu()
        else:
            ref_c = ref.to(cutlass_torch.dtype(c_dtype))

        # Compare the GPU result against the reference
        torch.testing.assert_close(
            gpu_c,
            ref_c,
            atol=tolerance,
            rtol=1e-05,
        )

    def generate_tensors():
        a_tensor, _ = cutlass_torch.cute_tensor_like(
            a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
        )
        b_tensor, _ = cutlass_torch.cute_tensor_like(
            b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
        )
        c_tensor, _ = cutlass_torch.cute_tensor_like(
            c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
        )
        return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)

    workspace_count = 1
    if use_cold_l2:
        one_workspace_bytes = (
            a_torch_cpu.numel() * a_torch_cpu.element_size()
            + b_torch_cpu.numel() * b_torch_cpu.element_size()
            + c_torch_cpu.numel() * c_torch_cpu.element_size()
        )
        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )

    exec_time = testing.benchmark(
        compiled_gemm,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        stream=current_stream,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    return exec_time  # Return execution time in microseconds


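# Example (illustrative, not executed): calling run() directly from Python instead of
# the CLI below. The problem size and dtypes here are hypothetical.
#
#   exec_time_us = run(
#       (8192, 8192, 8192, 1),
#       cutlass.Float16,
#       cutlass.Float16,
#       cutlass.Float32,
#       "k", "k", "n",
#       mma_tiler_mn=(256, 256),
#       cluster_shape_mn=(2, 1),
#       use_2cta_instrs=True,
#       use_tma_store=True,
#       iterations=10,
#   )
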
if __name__ == "__main__":

    def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
        try:
            return tuple(int(x.strip()) for x in s.split(","))
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    parser = argparse.ArgumentParser(
        description="Example of Dense Persistent GEMM on Blackwell."
    )

    parser.add_argument(
        "--mnkl",
        type=parse_comma_separated_ints,
        default=(256, 256, 512, 1),
        help="mnkl dimensions (comma-separated)",
    )
    parser.add_argument(
        "--mma_tiler_mn",
        type=parse_comma_separated_ints,
        default=(128, 128),
        help="Mma tile shape (comma-separated)",
    )
    parser.add_argument(
        "--cluster_shape_mn",
        type=parse_comma_separated_ints,
        default=(1, 1),
        help="Cluster shape (comma-separated)",
    )
    parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32)
    parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument(
        "--use_2cta_instrs",
        action="store_true",
        help="Enable 2CTA MMA instructions feature",
    )
    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
    parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
    parser.add_argument(
        "--use_tma_store", action="store_true", help="Whether to use TMA store"
    )
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    parser.add_argument(
        "--warmup_iterations", type=int, default=0, help="Warmup iterations"
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations to run the kernel",
    )
    parser.add_argument(
        "--skip_ref_check", action="store_true", help="Skip reference checking"
    )
    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        default=False,
        help="Use circular buffer tensor sets to ensure L2 cold cache",
    )

    args = parser.parse_args()

    if len(args.mnkl) != 4:
        parser.error("--mnkl must contain exactly 4 values")

    if len(args.mma_tiler_mn) != 2:
        parser.error("--mma_tiler_mn must contain exactly 2 values")

    if len(args.cluster_shape_mn) != 2:
        parser.error("--cluster_shape_mn must contain exactly 2 values")

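    # Example command line (illustrative; only a subset of flags is shown, the rest
    # fall back to their defaults above):
    #   python <this file> --mnkl 8192,8192,8192,1 --mma_tiler_mn 256,256 \
    #       --cluster_shape_mn 2,1 --use_2cta_instrs --use_tma_store --iterations 10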
    run(
        args.mnkl,
        args.ab_dtype,
        args.c_dtype,
        args.acc_dtype,
        args.a_major,
        args.b_major,
        args.c_major,
        args.mma_tiler_mn,
        args.cluster_shape_mn,
        args.use_2cta_instrs,
        args.use_tma_store,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )
    print("PASS")