* v4.3 update. * Update the cute_dsl_api changelog's doc link * Update version to 4.3.0 * Update the example link * Update doc to encourage user to install DSL from requirements.txt --------- Co-authored-by: Larry Wu <larwu@nvidia.com>
2928 lines
110 KiB
Python
2928 lines
110 KiB
Python
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
import argparse
|
|
from typing import Type, Tuple, Union
|
|
|
|
import cuda.bindings.driver as cuda
|
|
import torch
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.cute.testing as testing
|
|
from cutlass.cute.nvgpu import cpasync, tcgen05
|
|
import cutlass.torch as cutlass_torch
|
|
import cutlass.utils as utils
|
|
import cutlass.pipeline as pipeline
|
|
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
|
|
import math
|
|
|
|
|
|
"""
|
|
High-performance persistent blockwise dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture
|
|
using CUTE DSL.
|
|
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K")
|
|
- Matrix B is NxKxL, L is batch dimension, B can be column-major("K")
|
|
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
|
- Each block will apply the scale factor A
|
|
- Each row will apply the scale factor B
|
|
- For each iteration, the kernel will compute C = A * B and then apply the scale factor C *= SFA * SFB
|
|
|
|
This GEMM kernel supports the following features:
|
|
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
|
|
- Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations
|
|
- Implements TMA multicast with cluster to reduce L2 memory traffic
|
|
- Support persistent tile scheduling to better overlap memory load/store with mma between tiles
|
|
- Support warp specialization to avoid explicit pipelining between mainloop load and mma
|
|
|
|
This GEMM works as follows:
|
|
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
|
2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA operations.
|
|
2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
|
3. EPILOGUE warp:
|
|
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
|
- Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final
|
|
- Type convert Final matrix to output type.
|
|
- Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations.
|
|
|
|
SM100 tcgen05.mma instructions operate as follows:
|
|
- Read matrix A from SMEM
|
|
- Read matrix B from SMEM
|
|
- Write accumulator to TMEM
|
|
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/blackwell/blockwise_gemm/blockwise_gemm.py \
|
|
--ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \
|
|
--scale_dtype Float32 \
|
|
--mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \
|
|
--mnkl 4096,4096,4096,4
|
|
|
|
To collect performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
|
|
ncu python examples/blackwell/blockwise_gemm/blockwise_gemm.py \
|
|
--ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \
|
|
--scale_dtype Float32 \
|
|
--mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \
|
|
--mnkl 4096,4096,4096,4
|
|
|
|
|
|
Constraints are same as dense_gemm.py:
|
|
* Supported input data types: fp8 (e4m3fn)
|
|
see detailed valid dtype combinations in below BlockwiseGemmKernel class documentation
|
|
* A/B tensor must have the same data type
|
|
* Mma tiler M must be 64/128/256
|
|
* Mma tiler N must be 128, align with the scaleB requirement
|
|
* Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
|
* Cluster shape M must be multiple of 2
|
|
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned
|
|
"""
|
|
|
|
|
|
class BlockwiseGemmKernel:
|
|
"""This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, e5m2)
|
|
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
|
|
|
|
:param acc_dtype: Data type for accumulation during computation
|
|
:type acc_dtype: type[cutlass.Numeric]
|
|
:param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
|
|
:type use_2cta_instrs: bool
|
|
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
|
:type cluster_shape_mn: Tuple[int, int]
|
|
|
|
:note: Supported A/B data types:
|
|
- Float8E4M3FN
|
|
|
|
:note: Supported accumulator data types:
|
|
- Float32
|
|
|
|
:note: Supported C data types:
|
|
- Float16/BFloat16
|
|
- Other data types are not supported for accuracy issues
|
|
|
|
:note: Constraints:
|
|
- MMA tiler M must be 64/128/256
|
|
- MMA tiler N must be 128
|
|
- Cluster shape M must be multiple of 2
|
|
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
|
|
|
Example:
|
|
>>> gemm = BlockwiseGemmKernel(
|
|
... acc_dtype=cutlass.Float32,
|
|
... use_2cta_instrs=True,
|
|
... mma_tiler_mn=(128, 128),
|
|
... cluster_shape_mn=(2, 2)
|
|
... )
|
|
>>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
acc_dtype: Type[cutlass.Numeric],
|
|
use_2cta_instrs: bool,
|
|
mma_tiler_mn: Tuple[int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
):
|
|
"""Initializes the configuration for a Blackwell blockwise dense GEMM kernel.
|
|
|
|
This configuration includes several key aspects:
|
|
|
|
1. MMA Instruction Settings (tcgen05):
|
|
- acc_dtype: Data types for MMA accumulator.
|
|
- mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
|
|
- use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant
|
|
with cta_group=2 should be used.
|
|
|
|
2. Cluster Shape:
|
|
- cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
|
|
|
|
:param acc_dtype: Data type of the accumulator.
|
|
:type acc_dtype: type[cutlass.Numeric]
|
|
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant.
|
|
:type use_2cta_instrs: bool
|
|
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
|
|
:type cluster_shape_mn: Tuple[int, int]
|
|
"""
|
|
|
|
self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
|
|
self.use_2cta_instrs = use_2cta_instrs
|
|
self.cluster_shape_mn = cluster_shape_mn
|
|
# K dimension is deferred in _setup_attributes
|
|
self.mma_tiler = (*mma_tiler_mn, 1)
|
|
|
|
self.cta_group = (
|
|
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
|
)
|
|
|
|
self.occupancy = 1
|
|
# Set specialized warp ids
|
|
self.acc_update_warp_id = (0, 1, 2, 3)
|
|
self.epilog_warp_id = (4, 5, 6, 7)
|
|
self.mma_warp_id = 8
|
|
self.tma_warp_id = 9
|
|
self.scale_warp_id = 10
|
|
self.sched_warp_id = 11
|
|
self.threads_per_warp = 32
|
|
self.threads_per_cta = self.threads_per_warp * len(
|
|
(
|
|
*self.acc_update_warp_id,
|
|
*self.epilog_warp_id,
|
|
self.mma_warp_id,
|
|
self.tma_warp_id,
|
|
self.scale_warp_id,
|
|
self.sched_warp_id,
|
|
)
|
|
)
|
|
self.threads_wo_sched = self.threads_per_warp * len(
|
|
(
|
|
*self.acc_update_warp_id,
|
|
*self.epilog_warp_id,
|
|
self.mma_warp_id,
|
|
self.tma_warp_id,
|
|
self.scale_warp_id,
|
|
)
|
|
)
|
|
self.num_regs_uniform_warps = 64
|
|
self.num_regs_sched_warps = 64
|
|
self.num_regs_epilogue_warps = 216
|
|
self.num_regs_acc_update_warps = 216
|
|
|
|
# Set barrier for cta sync, epilogue sync and tmem ptr sync
|
|
self.cta_sync_barrier = pipeline.NamedBarrier(
|
|
barrier_id=1,
|
|
num_threads=self.threads_per_cta,
|
|
)
|
|
self.epilog_sync_barrier = pipeline.NamedBarrier(
|
|
barrier_id=2,
|
|
num_threads=32 * len(self.epilog_warp_id),
|
|
)
|
|
self.tmem_alloc_barrier = pipeline.NamedBarrier(
|
|
barrier_id=3,
|
|
num_threads=32
|
|
* len((self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id)),
|
|
)
|
|
self.sched_sync_barrier = pipeline.NamedBarrier(
|
|
barrier_id=4,
|
|
num_threads=self.threads_per_warp,
|
|
)
|
|
self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
|
# TMEM offset for final accumulator
|
|
self.tmem_final_offset = 384
|
|
|
|
def _setup_attributes(self):
|
|
"""Set up configurations that are dependent on GEMM inputs
|
|
|
|
This method configures various attributes based on the input tensor properties
|
|
(data types, leading dimensions) and kernel settings:
|
|
- Configuring tiled MMA
|
|
- Computing MMA/cluster/tile shapes
|
|
- Computing cluster layout
|
|
- Computing multicast CTAs for A/B
|
|
- Computing epilogue subtile
|
|
- Setting up A/B/C stage counts in shared memory
|
|
- Computing A/B/C shared memory layout
|
|
- Computing tensor memory allocation columns
|
|
"""
|
|
# Configure tiled mma
|
|
tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
self.a_dtype,
|
|
self.a_major_mode,
|
|
self.b_major_mode,
|
|
self.acc_dtype,
|
|
self.cta_group,
|
|
self.mma_tiler[:2],
|
|
)
|
|
|
|
# Compute mma/cluster/tile shapes
|
|
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
|
|
mma_inst_tile_k = 4
|
|
self.mma_tiler = (
|
|
self.mma_tiler[0],
|
|
self.mma_tiler[1],
|
|
mma_inst_shape_k * mma_inst_tile_k,
|
|
)
|
|
self.cta_tile_shape_mnk = (
|
|
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
|
self.mma_tiler[1],
|
|
self.mma_tiler[2],
|
|
)
|
|
|
|
# Compute cluster layout
|
|
self.cluster_layout_vmnk = cute.tiled_divide(
|
|
cute.make_layout((*self.cluster_shape_mn, 1)),
|
|
(tiled_mma.thr_id.shape,),
|
|
)
|
|
|
|
self.scale_granularity_m = 1
|
|
self.scale_granularity_n = 128
|
|
self.scale_granularity_k = 128
|
|
self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m
|
|
self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n
|
|
self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k
|
|
|
|
if self.scale_k_per_tile != 1:
|
|
raise ValueError("scale_k_per_tile must be 1")
|
|
if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]:
|
|
raise ValueError("scale_m_per_tile must be cta_tile_m")
|
|
if self.scale_n_per_tile != 1:
|
|
raise ValueError("scale_n_per_tile must be 1")
|
|
|
|
# Compute number of multicast CTAs for A/B
|
|
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
|
|
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
|
|
self.is_a_mcast = self.num_mcast_ctas_a > 1
|
|
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
|
|
|
# Compute epilogue subtile
|
|
self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
|
|
self.cta_tile_shape_mnk,
|
|
self.use_2cta_instrs,
|
|
self.c_layout,
|
|
self.c_dtype,
|
|
)
|
|
|
|
# Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory
|
|
(
|
|
self.num_acc_stage,
|
|
self.num_ab_stage,
|
|
self.num_c_stage,
|
|
self.num_scale_stage,
|
|
self.num_tile_stage,
|
|
) = self._compute_stages(
|
|
tiled_mma,
|
|
self.mma_tiler,
|
|
self.a_dtype,
|
|
self.b_dtype,
|
|
self.epi_tile,
|
|
self.c_dtype,
|
|
self.c_layout,
|
|
self.sfa_dtype,
|
|
self.sfb_dtype,
|
|
self.scale_m_per_tile * self.scale_k_per_tile,
|
|
self.scale_n_per_tile * self.scale_k_per_tile,
|
|
self.num_smem_capacity,
|
|
self.occupancy,
|
|
)
|
|
|
|
# Compute A/B/C/Scale shared memory layout
|
|
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
|
|
tiled_mma,
|
|
self.mma_tiler,
|
|
self.a_dtype,
|
|
self.num_ab_stage,
|
|
)
|
|
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
|
|
tiled_mma,
|
|
self.mma_tiler,
|
|
self.b_dtype,
|
|
self.num_ab_stage,
|
|
)
|
|
self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
|
|
self.c_dtype,
|
|
self.c_layout,
|
|
self.epi_tile,
|
|
self.num_c_stage,
|
|
)
|
|
self.sfa_smem_layout_staged = cute.make_layout(
|
|
(
|
|
(self.scale_granularity_m, self.scale_m_per_tile),
|
|
(self.scale_granularity_k, self.scale_k_per_tile),
|
|
self.num_scale_stage,
|
|
),
|
|
stride=(
|
|
(0, self.scale_k_per_tile),
|
|
(0, 1),
|
|
self.scale_k_per_tile * self.scale_m_per_tile,
|
|
),
|
|
)
|
|
self.sfb_smem_layout_staged = cute.make_layout(
|
|
(
|
|
(self.scale_granularity_n, self.scale_n_per_tile),
|
|
(self.scale_granularity_k, self.scale_k_per_tile),
|
|
self.num_scale_stage,
|
|
),
|
|
stride=(
|
|
(0, self.scale_k_per_tile),
|
|
(0, 1),
|
|
self.scale_k_per_tile * self.scale_n_per_tile,
|
|
),
|
|
)
|
|
|
|
# Compute the number of tensor memory allocation columns
|
|
self.num_tmem_alloc_cols = 512
|
|
|
|
@cute.jit
|
|
def __call__(
|
|
self,
|
|
a: cute.Tensor,
|
|
b: cute.Tensor,
|
|
c: cute.Tensor,
|
|
sfa: cute.Tensor,
|
|
sfb: cute.Tensor,
|
|
max_active_clusters: cutlass.Constexpr,
|
|
stream: cuda.CUstream,
|
|
epilogue_op: cutlass.Constexpr = lambda x: x,
|
|
):
|
|
"""Execute the GEMM operation in steps:
|
|
- Setup static attributes before smem/grid/tma computation
|
|
- Setup TMA load/store atoms and tensors
|
|
- Compute grid size with regard to hardware constraints
|
|
- Define shared storage for kernel
|
|
- Launch the kernel synchronously
|
|
|
|
:param a: Input tensor A
|
|
:type a: cute.Tensor
|
|
:param b: Input tensor B
|
|
:type b: cute.Tensor
|
|
:param c: Output tensor C
|
|
:type c: cute.Tensor
|
|
:param sfa: Scale factor tensor A
|
|
:type sfa: cute.Tensor
|
|
:param sfb: Scale factor tensor B
|
|
:type sfb: cute.Tensor
|
|
:param max_active_clusters: Maximum number of active clusters
|
|
:type max_active_clusters: cutlass.Constexpr
|
|
:param stream: CUDA stream for asynchronous execution
|
|
:type stream: cuda.CUstream
|
|
:param epilogue_op: Optional elementwise lambda function to apply to the output tensor
|
|
:type epilogue_op: cutlass.Constexpr
|
|
:raises TypeError: If input data types are incompatible with the MMA instruction.
|
|
"""
|
|
# Setup static attributes before smem/grid/tma computation
|
|
self.a_dtype: Type[cutlass.Numeric] = a.element_type
|
|
self.b_dtype: Type[cutlass.Numeric] = b.element_type
|
|
self.c_dtype: Type[cutlass.Numeric] = c.element_type
|
|
self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type
|
|
self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type
|
|
self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
|
|
self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
|
|
self.c_layout = utils.LayoutEnum.from_tensor(c)
|
|
|
|
# Check if input data types are compatible with MMA instruction
|
|
if cutlass.const_expr(self.a_dtype != self.b_dtype):
|
|
raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
|
|
|
|
# Setup attributes that dependent on gemm inputs
|
|
self._setup_attributes()
|
|
|
|
tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
|
self.a_dtype,
|
|
self.a_major_mode,
|
|
self.b_major_mode,
|
|
self.acc_dtype,
|
|
self.cta_group,
|
|
self.mma_tiler[:2],
|
|
)
|
|
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
|
|
|
# Setup TMA load for A
|
|
a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast)
|
|
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
|
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
|
a_op,
|
|
a,
|
|
a_smem_layout,
|
|
self.mma_tiler,
|
|
tiled_mma,
|
|
self.cluster_layout_vmnk.shape,
|
|
internal_type=(
|
|
cutlass.TFloat32 if a.element_type is cutlass.Float32 else None
|
|
),
|
|
)
|
|
|
|
# Setup TMA load for B
|
|
b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast)
|
|
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
|
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
|
b_op,
|
|
b,
|
|
b_smem_layout,
|
|
self.mma_tiler,
|
|
tiled_mma,
|
|
self.cluster_layout_vmnk.shape,
|
|
internal_type=(
|
|
cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
|
|
),
|
|
)
|
|
|
|
a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
|
|
b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
|
|
self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
|
|
|
|
# Setup TMA store for C
|
|
tma_atom_c = None
|
|
tma_tensor_c = None
|
|
c_cta_v_layout = cute.composition(
|
|
cute.make_identity_layout(c.shape), self.epi_tile
|
|
)
|
|
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
|
|
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
|
cpasync.CopyBulkTensorTileS2GOp(),
|
|
c,
|
|
epi_smem_layout,
|
|
c_cta_v_layout,
|
|
)
|
|
|
|
tensor_sfa = cute.make_tensor(
|
|
sfa.iterator,
|
|
cute.make_layout(
|
|
(
|
|
(self.scale_granularity_m, sfa.shape[0]),
|
|
(self.scale_granularity_k, sfa.shape[1]),
|
|
sfa.shape[2],
|
|
),
|
|
stride=(
|
|
(0, sfa.layout.stride[0]),
|
|
(0, sfa.layout.stride[1]),
|
|
sfa.layout.stride[2],
|
|
),
|
|
),
|
|
)
|
|
tensor_sfb = cute.make_tensor(
|
|
sfb.iterator,
|
|
cute.make_layout(
|
|
(
|
|
(self.scale_granularity_n, sfb.shape[0]),
|
|
(self.scale_granularity_k, sfb.shape[1]),
|
|
sfb.shape[2],
|
|
),
|
|
stride=(
|
|
(0, sfb.layout.stride[0]),
|
|
(0, sfb.layout.stride[1]),
|
|
sfb.layout.stride[2],
|
|
),
|
|
),
|
|
)
|
|
|
|
# Compute grid size
|
|
self.tile_sched_params, grid = self._compute_grid(
|
|
c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters
|
|
)
|
|
|
|
self.buffer_align_bytes = 1024
|
|
|
|
c_smem_size = cute.cosize(self.c_smem_layout_staged.outer)
|
|
|
|
# Define shared storage for kernel
|
|
@cute.struct
|
|
class SharedStorage:
|
|
# (bidx, bidy, bidz, valid)
|
|
sInfo: cute.struct.Align[
|
|
cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage],
|
|
# 1 byte alignment
|
|
1,
|
|
]
|
|
ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
|
scale_mbar_ptr: cute.struct.MemRange[
|
|
cutlass.Int64, self.num_scale_stage * 2
|
|
]
|
|
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
|
tile_info_mbar_ptr: cute.struct.MemRange[
|
|
cutlass.Int64, self.num_tile_stage * 2
|
|
]
|
|
epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2]
|
|
tmem_dealloc_mbar_ptr: cutlass.Int64
|
|
tmem_holding_buf: cutlass.Int32
|
|
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
|
sC: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.c_dtype,
|
|
c_smem_size,
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
sA: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
sB: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
# (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage)
|
|
sSFA: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged)
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
# (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
|
|
sSFB: cute.struct.Align[
|
|
cute.struct.MemRange[
|
|
self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged)
|
|
],
|
|
self.buffer_align_bytes,
|
|
]
|
|
|
|
self.shared_storage = SharedStorage
|
|
|
|
# Launch the kernel synchronously
|
|
self.kernel(
|
|
tiled_mma,
|
|
tma_atom_a,
|
|
tma_tensor_a,
|
|
tma_atom_b,
|
|
tma_tensor_b,
|
|
tma_atom_c,
|
|
tma_tensor_c,
|
|
tensor_sfa,
|
|
tensor_sfb,
|
|
self.cluster_layout_vmnk,
|
|
self.a_smem_layout_staged,
|
|
self.b_smem_layout_staged,
|
|
self.c_smem_layout_staged,
|
|
self.sfa_smem_layout_staged,
|
|
self.sfb_smem_layout_staged,
|
|
self.epi_tile,
|
|
self.tile_sched_params,
|
|
epilogue_op,
|
|
).launch(
|
|
grid=grid,
|
|
block=[self.threads_per_cta, 1, 1],
|
|
cluster=(*self.cluster_shape_mn, 1),
|
|
smem=self.shared_storage.size_in_bytes(),
|
|
stream=stream,
|
|
min_blocks_per_mp=1,
|
|
)
|
|
return
|
|
|
|
# GPU device kernel
|
|
@cute.kernel
|
|
def kernel(
|
|
self,
|
|
tiled_mma: cute.TiledMma,
|
|
tma_atom_a: cute.CopyAtom,
|
|
mA_mkl: cute.Tensor,
|
|
tma_atom_b: cute.CopyAtom,
|
|
mB_nkl: cute.Tensor,
|
|
tma_atom_c: cute.CopyAtom,
|
|
mC_mnl: cute.Tensor,
|
|
mSFA_mkl: cute.Tensor,
|
|
mSFB_nkl: cute.Tensor,
|
|
cluster_layout_vmnk: cute.Layout,
|
|
a_smem_layout_staged: cute.ComposedLayout,
|
|
b_smem_layout_staged: cute.ComposedLayout,
|
|
c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
|
|
sfa_smem_layout_staged: cute.Layout,
|
|
sfb_smem_layout_staged: cute.Layout,
|
|
epi_tile: cute.Tile,
|
|
tile_sched_params: utils.PersistentTileSchedulerParams,
|
|
epilogue_op: cutlass.Constexpr,
|
|
):
|
|
"""
|
|
GPU device kernel performing the Persistent batched GEMM computation.
|
|
"""
|
|
warp_idx = cute.arch.warp_idx()
|
|
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
|
lane_idx = cute.arch.lane_idx()
|
|
|
|
#
|
|
# Prefetch tma desc
|
|
#
|
|
if warp_idx == self.tma_warp_id:
|
|
cpasync.prefetch_descriptor(tma_atom_a)
|
|
cpasync.prefetch_descriptor(tma_atom_b)
|
|
cpasync.prefetch_descriptor(tma_atom_c)
|
|
|
|
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
|
|
|
|
#
|
|
# Setup cta/thread coordinates
|
|
#
|
|
# Coords inside cluster
|
|
bidx, bidy, bidz = cute.arch.block_idx()
|
|
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
|
|
is_leader_cta = mma_tile_coord_v == 0
|
|
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
|
cute.arch.block_idx_in_cluster()
|
|
)
|
|
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
|
|
cta_rank_in_cluster
|
|
)
|
|
# Coord inside cta
|
|
tidx, _, _ = cute.arch.thread_idx()
|
|
|
|
#
|
|
# Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
|
|
#
|
|
smem = utils.SmemAllocator()
|
|
storage = smem.allocate(self.shared_storage)
|
|
|
|
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
|
|
tmem_holding_buf = storage.tmem_holding_buf
|
|
|
|
# Initialize mainloop ab_pipeline (barrier) and states
|
|
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
|
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, num_tma_producer
|
|
)
|
|
ab_pipeline = pipeline.PipelineTmaUmma.create(
|
|
barrier_storage=storage.ab_mbar_ptr.data_ptr(),
|
|
num_stages=self.num_ab_stage,
|
|
producer_group=ab_pipeline_producer_group,
|
|
consumer_group=ab_pipeline_consumer_group,
|
|
tx_count=self.num_tma_load_bytes,
|
|
cta_layout_vmnk=cluster_layout_vmnk,
|
|
)
|
|
|
|
# Initialize mainloop scale_pipeline (barrier) and states
|
|
scale_pipeline_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.threads_per_warp * 1,
|
|
)
|
|
scale_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.threads_per_warp * len(self.epilog_warp_id),
|
|
)
|
|
scale_pipeline = pipeline.PipelineCpAsync.create(
|
|
barrier_storage=storage.scale_mbar_ptr.data_ptr(),
|
|
num_stages=self.num_scale_stage,
|
|
producer_group=scale_pipeline_producer_group,
|
|
consumer_group=scale_pipeline_consumer_group,
|
|
)
|
|
|
|
# Initialize acc_pipeline (barrier) and states
|
|
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
|
num_acc_consumer_threads = len(self.epilog_warp_id) * (
|
|
2 if use_2cta_instrs else 1
|
|
)
|
|
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, num_acc_consumer_threads
|
|
)
|
|
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
|
barrier_storage=storage.acc_mbar_ptr.data_ptr(),
|
|
num_stages=self.num_acc_stage,
|
|
producer_group=acc_pipeline_producer_group,
|
|
consumer_group=acc_pipeline_consumer_group,
|
|
cta_layout_vmnk=cluster_layout_vmnk,
|
|
)
|
|
|
|
# Initialize epilogue pipeline (barrier) and states
|
|
epi_pipeline_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.threads_per_warp * len(self.acc_update_warp_id),
|
|
)
|
|
epi_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.threads_per_warp * len(self.epilog_warp_id),
|
|
)
|
|
epi_pipeline = pipeline.PipelineAsync.create(
|
|
barrier_storage=storage.epi_mbar_ptr.data_ptr(),
|
|
num_stages=1,
|
|
producer_group=epi_pipeline_producer_group,
|
|
consumer_group=epi_pipeline_consumer_group,
|
|
)
|
|
|
|
# Initialize tile info pipeline (barrier) and states
|
|
tile_info_pipeline_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.threads_per_warp * 1,
|
|
)
|
|
tile_info_pipeline_consumer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
self.threads_wo_sched,
|
|
)
|
|
tile_info_pipeline = pipeline.PipelineAsync.create(
|
|
barrier_storage=storage.tile_info_mbar_ptr.data_ptr(),
|
|
num_stages=self.num_tile_stage,
|
|
producer_group=tile_info_pipeline_producer_group,
|
|
consumer_group=tile_info_pipeline_consumer_group,
|
|
)
|
|
|
|
# Tensor memory dealloc barrier init
|
|
tmem = utils.TmemAllocator(
|
|
storage.tmem_holding_buf,
|
|
barrier_for_retrieve=self.tmem_alloc_barrier,
|
|
allocator_warp_id=self.epilog_warp_id[0],
|
|
is_two_cta=use_2cta_instrs,
|
|
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
|
|
)
|
|
|
|
# Cluster arrive after barrier init
|
|
if cute.size(self.cluster_shape_mn) > 1:
|
|
cute.arch.cluster_arrive_relaxed()
|
|
|
|
#
|
|
# Setup smem tensor A/B/C/Scale
|
|
#
|
|
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
|
sC = storage.sC.get_tensor(
|
|
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
|
|
)
|
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
sA = storage.sA.get_tensor(
|
|
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
|
|
)
|
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
sB = storage.sB.get_tensor(
|
|
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
|
|
)
|
|
# (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage)
|
|
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
|
|
# (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage)
|
|
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
|
|
# (bidx, bidy, bidz, valid)
|
|
info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4))
|
|
sInfo = storage.sInfo.get_tensor(info_layout)
|
|
|
|
#
|
|
# Compute multicast mask for A/B buffer full
|
|
#
|
|
a_full_mcast_mask = None
|
|
b_full_mcast_mask = None
|
|
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
|
|
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
|
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
|
|
)
|
|
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
|
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
|
|
)
|
|
|
|
#
|
|
# Local_tile partition global tensors
|
|
#
|
|
# (bM, bK, loopM, loopK, loopL)
|
|
gA_mkl = cute.local_tile(
|
|
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
|
|
)
|
|
# (bN, bK, loopN, loopK, loopL)
|
|
gB_nkl = cute.local_tile(
|
|
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
|
|
)
|
|
# (bM, bN, loopM, loopN, loopL)
|
|
gC_mnl = cute.local_tile(
|
|
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
|
)
|
|
# (bM, bK, loopM, loopK, loopL)
|
|
gSFA_mkl = cute.local_tile(
|
|
mSFA_mkl,
|
|
cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)),
|
|
(None, None, None),
|
|
)
|
|
# (bN, bK, loopN, loopK, loopL)
|
|
gSFB_nkl = cute.local_tile(
|
|
mSFB_nkl,
|
|
cute.slice_(self.cta_tile_shape_mnk, (0, None, None)),
|
|
(None, None, None),
|
|
)
|
|
# coordinate
|
|
cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl))
|
|
cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl))
|
|
# (bM, bK, loopM, loopK, loopL)
|
|
cSFA = cute.local_tile(
|
|
cSFA_mkl,
|
|
cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)),
|
|
(None, None, None),
|
|
)
|
|
# (bN, bK, loopN, loopK, loopL)
|
|
cSFB = cute.local_tile(
|
|
cSFB_nkl,
|
|
cute.slice_(self.cta_tile_shape_mnk, (0, None, None)),
|
|
(None, None, None),
|
|
)
|
|
k_tile_cnt = cute.size(gA_mkl, mode=[3])
|
|
|
|
#
|
|
# Partition global tensor for TiledMMA_A/B/C
|
|
#
|
|
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
|
|
# (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
|
|
tCgA = thr_mma.partition_A(gA_mkl)
|
|
# (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
|
|
tCgB = thr_mma.partition_B(gB_nkl)
|
|
# (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
|
|
tCgC = thr_mma.partition_C(gC_mnl)
|
|
|
|
# scale viewed as C tensor
|
|
sSFA_view_as_C_layout = cute.make_layout(
|
|
(
|
|
(self.scale_granularity_m, self.scale_m_per_tile),
|
|
self.cta_tile_shape_mnk[1],
|
|
self.num_scale_stage,
|
|
),
|
|
stride=((0, 1), 0, self.scale_m_per_tile),
|
|
)
|
|
sSFB_view_as_C_layout = cute.make_layout(
|
|
(
|
|
self.cta_tile_shape_mnk[0],
|
|
(self.scale_granularity_n, self.scale_n_per_tile),
|
|
self.num_scale_stage,
|
|
),
|
|
stride=(0, (0, 1), self.scale_n_per_tile),
|
|
)
|
|
sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout)
|
|
sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout)
|
|
|
|
#
|
|
# Partition global/shared tensor for TMA load A/B
|
|
#
|
|
# TMA load A partition_S/D
|
|
a_cta_layout = cute.make_layout(
|
|
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
|
|
)
|
|
# ((atom_v, rest_v), STAGE)
|
|
# ((atom_v, rest_v), loopM, loopK, loopL)
|
|
tAsA, tAgA = cpasync.tma_partition(
|
|
tma_atom_a,
|
|
block_in_cluster_coord_vmnk[2],
|
|
a_cta_layout,
|
|
cute.group_modes(sA, 0, 3),
|
|
cute.group_modes(tCgA, 0, 3),
|
|
)
|
|
# TMA load B partition_S/D
|
|
b_cta_layout = cute.make_layout(
|
|
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
|
|
)
|
|
# ((atom_v, rest_v), STAGE)
|
|
# ((atom_v, rest_v), loopM, loopK, loopL)
|
|
tBsB, tBgB = cpasync.tma_partition(
|
|
tma_atom_b,
|
|
block_in_cluster_coord_vmnk[1],
|
|
b_cta_layout,
|
|
cute.group_modes(sB, 0, 3),
|
|
cute.group_modes(tCgB, 0, 3),
|
|
)
|
|
|
|
#
|
|
# Partition global/shared tensor for TMA load A/B
|
|
#
|
|
# load scaleA/scaleB
|
|
atom_copy = cute.make_copy_atom(
|
|
cute.nvgpu.cpasync.CopyG2SOp(),
|
|
mSFA_mkl.element_type,
|
|
num_bits_per_copy=mSFA_mkl.element_type.width,
|
|
)
|
|
tiled_copy_sfa = cute.make_tiled_copy_tv(
|
|
atom_copy, cute.make_layout((32,)), cute.make_layout((1,))
|
|
)
|
|
tiled_copy_sfb = cute.make_tiled_copy_tv(
|
|
atom_copy, cute.make_layout((32,)), cute.make_layout((1,))
|
|
)
|
|
thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx)
|
|
thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx)
|
|
# ((atom_v, rest_v), STAGE)
|
|
# ((atom_v, rest_v), loopM, loopK, loopL)
|
|
tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl)
|
|
tAsSFA = thr_copy_sfa.partition_D(sSFA)
|
|
tAcSFA = thr_copy_sfa.partition_S(cSFA)
|
|
# ((atom_v, rest_v), STAGE)
|
|
# ((atom_v, rest_v), loopN, loopK, loopL)
|
|
tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl)
|
|
tBsSFB = thr_copy_sfb.partition_D(sSFB)
|
|
tBcSFB = thr_copy_sfb.partition_S(cSFB)
|
|
|
|
#
|
|
# Partition shared/tensor memory tensor for TiledMMA_A/B/C
|
|
#
|
|
# (MMA, MMA_M, MMA_K, STAGE)
|
|
tCrA = tiled_mma.make_fragment_A(sA)
|
|
# (MMA, MMA_N, MMA_K, STAGE)
|
|
tCrB = tiled_mma.make_fragment_B(sB)
|
|
# (MMA, MMA_M, MMA_N)
|
|
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
tCtAcc_fake = tiled_mma.make_fragment_C(
|
|
cute.append(acc_shape, self.num_acc_stage)
|
|
)
|
|
|
|
#
|
|
# Cluster wait before tensor memory alloc
|
|
#
|
|
if cute.size(self.cluster_shape_mn) > 1:
|
|
cute.arch.cluster_wait()
|
|
else:
|
|
self.cta_sync_barrier.arrive_and_wait()
|
|
|
|
#
|
|
# Specialized Schedule warp
|
|
#
|
|
if warp_idx == self.sched_warp_id:
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps)
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
# First tile
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
tile_info_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.num_tile_stage
|
|
)
|
|
|
|
while work_tile.is_valid_tile:
|
|
# query next tile
|
|
tile_sched.advance_to_next_work()
|
|
work_tile = tile_sched.get_current_work()
|
|
|
|
# acquire tile info pipeline
|
|
tile_info_pipeline.producer_acquire(tile_info_producer_state)
|
|
|
|
# store the tile info
|
|
cur_tile_coord = work_tile.tile_idx
|
|
with cute.arch.elect_one():
|
|
sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
|
|
sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
|
|
sInfo[(2, tile_info_producer_state.index)] = cur_tile_coord[2]
|
|
sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
|
|
work_tile.is_valid_tile
|
|
)
|
|
|
|
# fence view async shared
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
self.sched_sync_barrier.arrive_and_wait()
|
|
# commit tile info pipeline
|
|
tile_info_pipeline.producer_commit(tile_info_producer_state)
|
|
tile_info_producer_state.advance()
|
|
|
|
tile_info_pipeline.producer_tail(tile_info_producer_state)
|
|
|
|
#
|
|
# Specialized TMA load warp
|
|
#
|
|
if warp_idx == self.tma_warp_id:
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
# First tile
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
ab_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.num_ab_stage
|
|
)
|
|
|
|
tile_info_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_tile_stage
|
|
)
|
|
|
|
# get the first tile info
|
|
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
# initialize the tile info
|
|
tile_info[0] = cur_tile_coord[0]
|
|
tile_info[1] = cur_tile_coord[1]
|
|
tile_info[2] = cur_tile_coord[2]
|
|
tile_info[3] = work_tile.is_valid_tile
|
|
|
|
is_valid_tile = cutlass.Boolean(1)
|
|
is_valid_tile = tile_info[3] == 1
|
|
|
|
while is_valid_tile:
|
|
mma_tile_coord_mnl = (
|
|
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
|
|
tile_info[1],
|
|
tile_info[2],
|
|
)
|
|
#
|
|
# Slice to per mma tile index
|
|
#
|
|
# ((atom_v, rest_v), loopK)
|
|
tAgA_slice = tAgA[
|
|
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
|
|
]
|
|
# ((atom_v, rest_v), loopK)
|
|
tBgB_slice = tBgB[
|
|
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
|
]
|
|
|
|
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
|
|
ab_producer_state.reset_count()
|
|
peek_ab_empty_status = cutlass.Boolean(1)
|
|
if ab_producer_state.count < k_tile_cnt:
|
|
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
|
ab_producer_state
|
|
)
|
|
#
|
|
# Tma load loop
|
|
#
|
|
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
|
tAgA_k = tAgA_slice[(None, ab_producer_state.count)]
|
|
tBgB_k = tBgB_slice[(None, ab_producer_state.count)]
|
|
tAsA_pipe = tAsA[(None, ab_producer_state.index)]
|
|
tBsB_pipe = tBsB[(None, ab_producer_state.index)]
|
|
|
|
tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state)
|
|
|
|
# Conditionally wait for AB buffer empty
|
|
ab_pipeline.producer_acquire(
|
|
ab_producer_state, peek_ab_empty_status
|
|
)
|
|
|
|
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
|
ab_producer_state.advance()
|
|
peek_ab_empty_status = cutlass.Boolean(1)
|
|
if ab_producer_state.count < k_tile_cnt:
|
|
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
|
ab_producer_state
|
|
)
|
|
|
|
# TMA load A/B
|
|
cute.copy(
|
|
tma_atom_a,
|
|
tAgA_k,
|
|
tAsA_pipe,
|
|
tma_bar_ptr=tma_bar,
|
|
mcast_mask=a_full_mcast_mask,
|
|
)
|
|
cute.copy(
|
|
tma_atom_b,
|
|
tBgB_k,
|
|
tBsB_pipe,
|
|
tma_bar_ptr=tma_bar,
|
|
mcast_mask=b_full_mcast_mask,
|
|
)
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
|
|
for idx in cutlass.range(4, unroll_full=True):
|
|
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
|
|
is_valid_tile = tile_info[3] == 1
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
tile_info_pipeline.consumer_release(tile_info_consumer_state)
|
|
tile_info_consumer_state.advance()
|
|
|
|
#
|
|
# Wait A/B buffer empty
|
|
#
|
|
ab_pipeline.producer_tail(ab_producer_state)
|
|
|
|
#
|
|
# Specialized Scale load warp
|
|
#
|
|
if warp_idx == self.scale_warp_id:
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
# First tile
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
scale_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.num_scale_stage
|
|
)
|
|
|
|
tile_info_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_tile_stage
|
|
)
|
|
|
|
# get the first tile info
|
|
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
# initialize the tile info
|
|
tile_info[0] = cur_tile_coord[0]
|
|
tile_info[1] = cur_tile_coord[1]
|
|
tile_info[2] = cur_tile_coord[2]
|
|
tile_info[3] = work_tile.is_valid_tile
|
|
|
|
is_valid_tile = cutlass.Boolean(1)
|
|
is_valid_tile = tile_info[3] == 1
|
|
|
|
while is_valid_tile:
|
|
#
|
|
# Prepare the mask for scaleA/scaleB
|
|
#
|
|
tApSFA = cute.make_rmem_tensor(
|
|
cute.make_layout(
|
|
cute.filter_zeros(
|
|
cute.slice_(tAsSFA, (None, None, None, 0))
|
|
).shape
|
|
),
|
|
cutlass.Boolean,
|
|
)
|
|
tBpSFB = cute.make_rmem_tensor(
|
|
cute.make_layout(
|
|
cute.filter_zeros(
|
|
cute.slice_(tBsSFB, (None, None, None, 0))
|
|
).shape
|
|
),
|
|
cutlass.Boolean,
|
|
)
|
|
|
|
# Peek (try_wait) SCALE buffer empty
|
|
scale_producer_state.reset_count()
|
|
peek_scale_empty_status = cutlass.Boolean(1)
|
|
if scale_producer_state.count < k_tile_cnt:
|
|
peek_scale_empty_status = scale_pipeline.producer_try_acquire(
|
|
scale_producer_state
|
|
)
|
|
|
|
#
|
|
# load loop
|
|
#
|
|
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
|
#
|
|
# Slice to per mma tile index
|
|
#
|
|
tAsSFA_pipe = cute.filter_zeros(
|
|
tAsSFA[(None, None, None, scale_producer_state.index)]
|
|
)
|
|
tBsSFB_pipe = cute.filter_zeros(
|
|
tBsSFB[(None, None, None, scale_producer_state.index)]
|
|
)
|
|
tAgSFA_k = cute.filter_zeros(
|
|
tAgSFA_mkl[
|
|
(
|
|
None,
|
|
None,
|
|
None,
|
|
tile_info[0],
|
|
scale_producer_state.count,
|
|
tile_info[2],
|
|
)
|
|
]
|
|
)
|
|
tBgSFB_k = cute.filter_zeros(
|
|
tBgSFB_nkl[
|
|
(
|
|
None,
|
|
None,
|
|
None,
|
|
tile_info[1],
|
|
scale_producer_state.count,
|
|
tile_info[2],
|
|
)
|
|
]
|
|
)
|
|
|
|
tAcSFA_compact = cute.filter_zeros(
|
|
cute.slice_(
|
|
tAcSFA,
|
|
(
|
|
None,
|
|
None,
|
|
None,
|
|
tile_info[0],
|
|
scale_producer_state.count,
|
|
tile_info[2],
|
|
),
|
|
)
|
|
)
|
|
tBcSFB_compact = cute.filter_zeros(
|
|
cute.slice_(
|
|
tBcSFB,
|
|
(
|
|
None,
|
|
None,
|
|
None,
|
|
tile_info[1],
|
|
scale_producer_state.count,
|
|
tile_info[2],
|
|
),
|
|
)
|
|
)
|
|
for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])):
|
|
tApSFA[((0, 0), i, (0, 0))] = cute.elem_less(
|
|
tAcSFA_compact[(i)][0], mSFA_mkl.shape[0]
|
|
)
|
|
for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])):
|
|
tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less(
|
|
tBcSFB_compact[(i)][0], mSFB_nkl.shape[0]
|
|
)
|
|
|
|
# Conditionally wait for Scale buffer empty
|
|
scale_pipeline.producer_acquire(
|
|
scale_producer_state, peek_scale_empty_status
|
|
)
|
|
|
|
# load scaleA/scaleB
|
|
cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA)
|
|
cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB)
|
|
|
|
scale_pipeline.producer_commit(scale_producer_state)
|
|
|
|
# Peek (try_wait) Scale buffer empty
|
|
scale_producer_state.advance()
|
|
peek_scale_empty_status = cutlass.Boolean(1)
|
|
if scale_producer_state.count < k_tile_cnt:
|
|
peek_scale_empty_status = scale_pipeline.producer_try_acquire(
|
|
scale_producer_state
|
|
)
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
|
|
for idx in cutlass.range(4, unroll_full=True):
|
|
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
|
|
is_valid_tile = tile_info[3] == 1
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
tile_info_pipeline.consumer_release(tile_info_consumer_state)
|
|
tile_info_consumer_state.advance()
|
|
|
|
#
|
|
# Wait Scale buffer empty
|
|
#
|
|
scale_pipeline.producer_tail(scale_producer_state)
|
|
|
|
#
|
|
# Specialized MMA warp
|
|
#
|
|
if warp_idx == self.mma_warp_id:
|
|
cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps)
|
|
#
|
|
# Bar sync for retrieve tensor memory ptr from shared mem
|
|
#
|
|
tmem.wait_for_alloc()
|
|
|
|
#
|
|
# Retrieving tensor memory ptr and make accumulator tensor
|
|
#
|
|
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
|
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
|
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
ab_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
|
)
|
|
acc_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, self.num_acc_stage
|
|
)
|
|
|
|
tile_info_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_tile_stage
|
|
)
|
|
|
|
# get the first tile info
|
|
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
# initialize the tile info
|
|
tile_info[0] = cur_tile_coord[0]
|
|
tile_info[1] = cur_tile_coord[1]
|
|
tile_info[2] = cur_tile_coord[2]
|
|
tile_info[3] = work_tile.is_valid_tile
|
|
|
|
is_valid_tile = cutlass.Boolean(1)
|
|
is_valid_tile = tile_info[3] == 1
|
|
|
|
while is_valid_tile:
|
|
# Peek (try_wait) AB buffer full for k_tile = 0
|
|
ab_consumer_state.reset_count()
|
|
peek_ab_full_status = cutlass.Boolean(1)
|
|
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
|
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
|
ab_consumer_state
|
|
)
|
|
|
|
# Peek (try_wait) Acc buffer empty for k_tile = 0
|
|
acc_producer_state.reset_count()
|
|
peek_acc_empty_status = cutlass.Boolean(1)
|
|
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
|
peek_acc_empty_status = acc_pipeline.producer_try_acquire(
|
|
acc_producer_state
|
|
)
|
|
|
|
#
|
|
# Mma mainloop
|
|
#
|
|
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
|
# Set tensor memory buffer for current tile
|
|
# (MMA, MMA_M, MMA_N)
|
|
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
|
|
|
|
#
|
|
# Wait for accumulator buffer empty
|
|
#
|
|
if is_leader_cta:
|
|
acc_pipeline.producer_acquire(
|
|
acc_producer_state, peek_acc_empty_status
|
|
)
|
|
|
|
#
|
|
# Reset the ACCUMULATE field for each tile
|
|
#
|
|
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
|
|
|
if is_leader_cta:
|
|
# Conditionally wait for AB buffer full
|
|
ab_pipeline.consumer_wait(
|
|
ab_consumer_state, peek_ab_full_status
|
|
)
|
|
|
|
# tCtAcc += tCrA * tCrB
|
|
num_kblocks = cute.size(tCrA, mode=[2])
|
|
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
|
kblock_coord = (
|
|
None,
|
|
None,
|
|
kblock_idx,
|
|
ab_consumer_state.index,
|
|
)
|
|
|
|
cute.gemm(
|
|
tiled_mma,
|
|
tCtAcc,
|
|
tCrA[kblock_coord],
|
|
tCrB[kblock_coord],
|
|
tCtAcc,
|
|
)
|
|
# Enable accumulate on tCtAcc after first kblock
|
|
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
|
|
|
# Async arrive AB buffer empty
|
|
ab_pipeline.consumer_release(ab_consumer_state)
|
|
|
|
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
|
ab_consumer_state.advance()
|
|
peek_ab_full_status = cutlass.Boolean(1)
|
|
if ab_consumer_state.count < k_tile_cnt:
|
|
if is_leader_cta:
|
|
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
|
ab_consumer_state
|
|
)
|
|
|
|
#
|
|
# Async arrive accumulator buffer full(each kblock)
|
|
#
|
|
if is_leader_cta:
|
|
acc_pipeline.producer_commit(acc_producer_state)
|
|
|
|
# Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1
|
|
acc_producer_state.advance()
|
|
if acc_producer_state.count < k_tile_cnt:
|
|
if is_leader_cta:
|
|
peek_acc_empty_status = acc_pipeline.producer_try_acquire(
|
|
acc_producer_state
|
|
)
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
|
|
for idx in cutlass.range(4, unroll_full=True):
|
|
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
|
|
is_valid_tile = tile_info[3] == 1
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
tile_info_pipeline.consumer_release(tile_info_consumer_state)
|
|
tile_info_consumer_state.advance()
|
|
|
|
#
|
|
# Wait for accumulator buffer empty
|
|
#
|
|
acc_pipeline.producer_tail(acc_producer_state)
|
|
|
|
#
|
|
# Specialized acc update warps
|
|
#
|
|
if warp_idx <= self.acc_update_warp_id[-1]:
|
|
cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps)
|
|
#
|
|
# Bar sync for retrieve tensor memory ptr from shared memory
|
|
#
|
|
tmem.wait_for_alloc()
|
|
|
|
#
|
|
# Retrieving tensor memory ptr and make accumulator tensor
|
|
#
|
|
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
|
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
|
tCtAcc_final = cute.make_tensor(
|
|
tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout
|
|
)
|
|
|
|
#
|
|
# Partition for epilogue
|
|
#
|
|
epi_tidx = tidx % 128
|
|
(
|
|
tiled_copy_t2r,
|
|
tiled_copy_r2t,
|
|
tTR_tAcc_base,
|
|
tTR_rAcc,
|
|
tTR_rAcc_final,
|
|
tTR_sSFA,
|
|
tTR_sSFB,
|
|
tRT_rAcc,
|
|
tRT_tAcc_base,
|
|
) = self.acc_update_tmem_copy_and_partition(
|
|
epi_tidx,
|
|
tCtAcc_base,
|
|
tCtAcc_final,
|
|
tCgC,
|
|
sSFA_view_as_C,
|
|
sSFB_view_as_C,
|
|
epi_tile,
|
|
)
|
|
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
acc_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
|
)
|
|
|
|
scale_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_scale_stage
|
|
)
|
|
|
|
epi_producer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Producer, 1
|
|
)
|
|
|
|
tile_info_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_tile_stage
|
|
)
|
|
|
|
# get the first tile info
|
|
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
# initialize the tile info
|
|
tile_info[0] = cur_tile_coord[0]
|
|
tile_info[1] = cur_tile_coord[1]
|
|
tile_info[2] = cur_tile_coord[2]
|
|
tile_info[3] = work_tile.is_valid_tile
|
|
|
|
is_valid_tile = cutlass.Boolean(1)
|
|
is_valid_tile = tile_info[3] == 1
|
|
|
|
while is_valid_tile:
|
|
# initialize the final accumulator
|
|
tTR_rAcc_final.fill(0.0)
|
|
|
|
tTR_rSFA = cute.make_rmem_tensor(
|
|
cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape,
|
|
self.acc_dtype,
|
|
)
|
|
tTR_rSFB = cute.make_rmem_tensor(
|
|
cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape,
|
|
self.acc_dtype,
|
|
)
|
|
|
|
scale_consumer_state.reset_count()
|
|
peek_scale_full_status = cutlass.Boolean(1)
|
|
if scale_consumer_state.count < k_tile_cnt:
|
|
peek_scale_full_status = scale_pipeline.consumer_try_wait(
|
|
scale_consumer_state
|
|
)
|
|
|
|
acc_consumer_state.reset_count()
|
|
peek_acc_full_status = cutlass.Boolean(1)
|
|
if acc_consumer_state.count < k_tile_cnt:
|
|
peek_acc_full_status = acc_pipeline.consumer_try_wait(
|
|
acc_consumer_state
|
|
)
|
|
|
|
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
|
# Set tensor memory buffer for current tile
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
|
|
tTR_tAcc = tTR_tAcc_base[
|
|
(None, None, None, None, None, acc_consumer_state.index)
|
|
]
|
|
|
|
#
|
|
# Wait for scale buffer full
|
|
#
|
|
scale_pipeline.consumer_wait(
|
|
scale_consumer_state, peek_scale_full_status
|
|
)
|
|
|
|
tTR_sSFA_slice = cute.slice_(
|
|
tTR_sSFA,
|
|
(None, None, None, 0, None, scale_consumer_state.index),
|
|
)
|
|
tTR_sSFB_slice = cute.slice_(
|
|
tTR_sSFB,
|
|
(None, None, None, 0, None, scale_consumer_state.index),
|
|
)
|
|
|
|
scale_atom_copy = cute.make_copy_atom(
|
|
cute.nvgpu.CopyUniversalOp(),
|
|
self.acc_dtype,
|
|
num_bits_per_copy=self.acc_dtype.width,
|
|
)
|
|
|
|
cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA)
|
|
cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB)
|
|
|
|
#
|
|
# Wait for accumulator buffer full
|
|
#
|
|
|
|
acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status)
|
|
|
|
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
|
|
|
#
|
|
# Update accumulator by scale factor in subtiles
|
|
#
|
|
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
|
for subtile_idx in cutlass.range(subtile_cnt):
|
|
#
|
|
# Load accumulator from tensor memory buffer to register
|
|
#
|
|
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
|
|
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
|
|
|
#
|
|
# Update accumulator by scale factor
|
|
#
|
|
tTR_rAcc_subtile = tTR_rAcc_final[
|
|
(None, None, None, subtile_idx)
|
|
]
|
|
tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)]
|
|
tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)]
|
|
|
|
acc_vec = tTR_rAcc.load()
|
|
final_vec = tTR_rAcc_subtile.load()
|
|
scale_a = tTR_rSFA_subtile.load()
|
|
scale_b = tTR_rSFB_subtile.load()
|
|
scale = scale_a * scale_b
|
|
final_vec = acc_vec * scale + final_vec
|
|
tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype))
|
|
|
|
#
|
|
# Async arrive accumulator buffer empty
|
|
#
|
|
scale_pipeline.consumer_release(scale_consumer_state)
|
|
scale_consumer_state.advance()
|
|
|
|
peek_scale_full_status = cutlass.Boolean(1)
|
|
if scale_consumer_state.count < k_tile_cnt:
|
|
peek_scale_full_status = scale_pipeline.consumer_try_wait(
|
|
scale_consumer_state
|
|
)
|
|
#
|
|
# Async arrive accumulator buffer empty
|
|
#
|
|
with cute.arch.elect_one():
|
|
acc_pipeline.consumer_release(acc_consumer_state)
|
|
acc_consumer_state.advance()
|
|
|
|
peek_acc_full_status = cutlass.Boolean(1)
|
|
if acc_consumer_state.count < k_tile_cnt:
|
|
peek_acc_full_status = acc_pipeline.consumer_try_wait(
|
|
acc_consumer_state
|
|
)
|
|
|
|
tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)]
|
|
tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc))
|
|
|
|
#
|
|
# Wait for epilogue buffer empty
|
|
#
|
|
epi_pipeline.producer_acquire(epi_producer_state)
|
|
|
|
# copy the accumulator to tensor memory buffer
|
|
cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc)
|
|
cute.arch.fence_view_async_tmem_store()
|
|
|
|
#
|
|
# Async arrive epilogue buffer full
|
|
#
|
|
epi_pipeline.producer_commit(epi_producer_state)
|
|
epi_producer_state.advance()
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
|
|
for idx in cutlass.range(4, unroll_full=True):
|
|
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
|
|
is_valid_tile = tile_info[3] == 1
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
tile_info_pipeline.consumer_release(tile_info_consumer_state)
|
|
tile_info_consumer_state.advance()
|
|
|
|
#
|
|
# Specialized epilogue warps
|
|
#
|
|
if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]:
|
|
cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps)
|
|
#
|
|
# Alloc tensor memory buffer
|
|
#
|
|
tmem.allocate(self.num_tmem_alloc_cols)
|
|
|
|
#
|
|
# Bar sync for retrieve tensor memory ptr from shared memory
|
|
#
|
|
tmem.wait_for_alloc()
|
|
|
|
#
|
|
# Retrieving tensor memory ptr and make accumulator tensor
|
|
#
|
|
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
|
|
# (MMA, MMA_M, MMA_N, STAGE)
|
|
tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
|
tCtAcc_final = cute.make_tensor(
|
|
tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout
|
|
)
|
|
|
|
#
|
|
# Partition for epilogue
|
|
#
|
|
epi_tidx = tidx % 128
|
|
(
|
|
tiled_copy_t2r,
|
|
tTR_tAcc_base,
|
|
tTR_rAcc,
|
|
) = self.epilog_tmem_copy_and_partition(
|
|
epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs
|
|
)
|
|
|
|
tTR_rC = None
|
|
tiled_copy_r2s = None
|
|
simt_atom = None
|
|
tRS_rC = None
|
|
tRS_sC = None
|
|
bSG_sC = None
|
|
bSG_gC_partitioned = None
|
|
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
|
|
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
|
|
tiled_copy_t2r, tTR_rC, epi_tidx, sC
|
|
)
|
|
(
|
|
tma_atom_c,
|
|
bSG_sC,
|
|
bSG_gC_partitioned,
|
|
) = self.epilog_gmem_copy_and_partition(
|
|
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
|
|
)
|
|
|
|
#
|
|
# Persistent tile scheduling loop
|
|
#
|
|
tile_sched = utils.StaticPersistentTileScheduler.create(
|
|
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
|
)
|
|
work_tile = tile_sched.initial_work_tile_info()
|
|
|
|
epi_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, 1
|
|
)
|
|
|
|
c_pipeline = None
|
|
# Threads/warps participating in tma store pipeline
|
|
c_producer_group = pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread,
|
|
32 * len(self.epilog_warp_id),
|
|
)
|
|
c_pipeline = pipeline.PipelineTmaStore.create(
|
|
num_stages=self.num_c_stage,
|
|
producer_group=c_producer_group,
|
|
)
|
|
|
|
tile_info_consumer_state = pipeline.make_pipeline_state(
|
|
pipeline.PipelineUserType.Consumer, self.num_tile_stage
|
|
)
|
|
|
|
# get the first tile info
|
|
tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
|
|
# Get tile coord from tile scheduler
|
|
cur_tile_coord = work_tile.tile_idx
|
|
# initialize the tile info
|
|
tile_info[0] = cur_tile_coord[0]
|
|
tile_info[1] = cur_tile_coord[1]
|
|
tile_info[2] = cur_tile_coord[2]
|
|
tile_info[3] = work_tile.is_valid_tile
|
|
|
|
is_valid_tile = cutlass.Boolean(1)
|
|
is_valid_tile = tile_info[3] == 1
|
|
|
|
num_prev_subtiles = cutlass.Int32(0)
|
|
|
|
while is_valid_tile:
|
|
mma_tile_coord_mnl = (
|
|
tile_info[0] // cute.size(tiled_mma.thr_id.shape),
|
|
tile_info[1],
|
|
tile_info[2],
|
|
)
|
|
#
|
|
# Slice to per mma tile index
|
|
#
|
|
bSG_gC = None
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
|
bSG_gC = bSG_gC_partitioned[
|
|
(
|
|
None,
|
|
None,
|
|
None,
|
|
mma_tile_coord_mnl[0],
|
|
mma_tile_coord_mnl[1],
|
|
mma_tile_coord_mnl[2],
|
|
)
|
|
]
|
|
|
|
# Set tensor memory buffer for current tile
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
|
|
tTR_tAcc = tTR_tAcc_base[
|
|
(None, None, None, None, None, epi_consumer_state.index)
|
|
]
|
|
|
|
#
|
|
# Wait for accumulator buffer full
|
|
#
|
|
epi_pipeline.consumer_wait(epi_consumer_state)
|
|
|
|
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
|
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
|
|
|
|
#
|
|
# Store accumulator to global memory in subtiles
|
|
#
|
|
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
|
for subtile_idx in cutlass.range(subtile_cnt):
|
|
#
|
|
# Load accumulator from tensor memory buffer to register
|
|
#
|
|
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
|
|
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
|
|
|
#
|
|
# Convert to C type
|
|
#
|
|
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
|
|
acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
|
|
tRS_rC.store(acc_vec)
|
|
|
|
#
|
|
# Store C to shared memory
|
|
#
|
|
num_prev_subtiles = num_prev_subtiles + 1
|
|
c_buffer = num_prev_subtiles % self.num_c_stage
|
|
cute.copy(
|
|
tiled_copy_r2s,
|
|
tRS_rC,
|
|
tRS_sC[(None, None, None, c_buffer)],
|
|
)
|
|
# Fence and barrier to make sure shared memory store is visible to TMA store
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
self.epilog_sync_barrier.arrive_and_wait()
|
|
|
|
#
|
|
# TMA store C to global memory
|
|
#
|
|
if warp_idx == self.epilog_warp_id[0]:
|
|
cute.copy(
|
|
tma_atom_c,
|
|
bSG_sC[(None, c_buffer)],
|
|
bSG_gC[(None, subtile_idx)],
|
|
)
|
|
# Fence and barrier to make sure shared memory store is visible to TMA store
|
|
c_pipeline.producer_commit()
|
|
c_pipeline.producer_acquire()
|
|
self.epilog_sync_barrier.arrive_and_wait()
|
|
|
|
#
|
|
# Async arrive accumulator buffer empty
|
|
#
|
|
epi_pipeline.consumer_release(epi_consumer_state)
|
|
epi_consumer_state.advance()
|
|
|
|
#
|
|
# Advance to next tile
|
|
#
|
|
tile_info_pipeline.consumer_wait(tile_info_consumer_state)
|
|
for idx in cutlass.range(4, unroll_full=True):
|
|
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
|
|
is_valid_tile = tile_info[3] == 1
|
|
cute.arch.fence_proxy(
|
|
cute.arch.ProxyKind.async_shared,
|
|
space=cute.arch.SharedSpace.shared_cta,
|
|
)
|
|
tile_info_pipeline.consumer_release(tile_info_consumer_state)
|
|
tile_info_consumer_state.advance()
|
|
|
|
#
|
|
# Dealloc the tensor memory buffer
|
|
#
|
|
tmem.relinquish_alloc_permit()
|
|
self.epilog_sync_barrier.arrive_and_wait()
|
|
tmem.free(tmem_ptr)
|
|
#
|
|
# Wait for C store complete
|
|
#
|
|
c_pipeline.producer_tail()
|
|
|
|
def acc_update_tmem_copy_and_partition(
|
|
self,
|
|
tidx: cutlass.Int32,
|
|
tAcc: cute.Tensor,
|
|
tAcc_final: cute.Tensor,
|
|
gC_mnl: cute.Tensor,
|
|
sSFA: cute.Tensor,
|
|
sSFB: cute.Tensor,
|
|
epi_tile: cute.Tile,
|
|
) -> Tuple[
|
|
cute.TiledCopy,
|
|
cute.TiledCopy,
|
|
cute.Tensor,
|
|
cute.Tensor,
|
|
cute.Tensor,
|
|
cute.Tensor,
|
|
cute.Tensor,
|
|
cute.Tensor,
|
|
cute.Tensor,
|
|
]:
|
|
"""
|
|
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
|
|
Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination).
|
|
Partition the scale factor tensor for related copy operations.
|
|
|
|
:param tidx: The thread index in epilogue warp groups
|
|
:type tidx: cutlass.Int32
|
|
:param tAcc: The accumulator tensor to be copied and partitioned
|
|
:type tAcc: cute.Tensor
|
|
:param tAcc_final: The final accumulator tensor to be copied and partitioned
|
|
:type tAcc_final: cute.Tensor
|
|
:param gC_mnl: The global tensor C
|
|
:type gC_mnl: cute.Tensor
|
|
:param sSFA: The scale factor tensor for A
|
|
:type sSFA: cute.Tensor
|
|
:param sSFB: The scale factor tensor for B
|
|
:type sSFB: cute.Tensor
|
|
:param epi_tile: The epilogue tiler
|
|
:type epi_tile: cute.Tile
|
|
|
|
:return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
|
|
- tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
|
|
- tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t)
|
|
- tTR_tAcc: The partitioned accumulator tensor
|
|
- tTR_rAcc: The accumulated tensor in register used to hold t2r results
|
|
- tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results
|
|
- tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r
|
|
- tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r
|
|
- tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results
|
|
- tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t
|
|
:rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]
|
|
"""
|
|
# Make tiledCopy for tensor memory load
|
|
tmem_load_atom = None
|
|
tmem_store_atom = None
|
|
if cutlass.const_expr(self.mma_tiler[0] == 64):
|
|
tmem_load_atom = cute.make_copy_atom(
|
|
tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)),
|
|
self.acc_dtype,
|
|
)
|
|
elif cutlass.const_expr(self.mma_tiler[0] == 128):
|
|
tmem_load_atom = cute.make_copy_atom(
|
|
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)),
|
|
self.acc_dtype,
|
|
)
|
|
else:
|
|
# default: 16dp
|
|
tmem_load_atom = cute.make_copy_atom(
|
|
tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(1)),
|
|
self.acc_dtype,
|
|
)
|
|
if cutlass.const_expr(self.mma_tiler[0] == 64):
|
|
tmem_store_atom = cute.make_copy_atom(
|
|
tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)),
|
|
self.acc_dtype,
|
|
)
|
|
elif cutlass.const_expr(self.mma_tiler[0] == 128):
|
|
tmem_store_atom = cute.make_copy_atom(
|
|
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)),
|
|
self.acc_dtype,
|
|
)
|
|
else:
|
|
# default: 16dp
|
|
tmem_store_atom = cute.make_copy_atom(
|
|
tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(1)),
|
|
self.acc_dtype,
|
|
)
|
|
|
|
tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
|
|
tAcc_final_epi = cute.flat_divide(
|
|
tAcc_final[((None, None), 0, 0, None)], epi_tile
|
|
)
|
|
|
|
tiled_copy_t2r = tcgen05.make_tmem_copy(
|
|
tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)]
|
|
)
|
|
tiled_copy_r2t = tcgen05.make_tmem_copy(
|
|
tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)]
|
|
)
|
|
|
|
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
|
thr_copy_r2t = tiled_copy_r2t.get_slice(tidx)
|
|
|
|
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
|
gC_mnl_epi = cute.flat_divide(
|
|
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
|
)
|
|
sSFA_epi = cute.flat_divide(sSFA, epi_tile)
|
|
sSFB_epi = cute.flat_divide(sSFB, epi_tile)
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
|
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
|
|
tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi)
|
|
tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi)
|
|
# (T2R, T2R_M, T2R_N)
|
|
tTR_rAcc = cute.make_rmem_tensor(
|
|
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype
|
|
)
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
|
tTR_rAcc_final_ = cute.make_rmem_tensor(
|
|
tTR_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype
|
|
)
|
|
tTR_rAcc_final = cute.group_modes(
|
|
tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_)
|
|
)
|
|
|
|
tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi)
|
|
tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi)
|
|
# (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
|
tRT_rAcc_final_ = cute.make_rmem_tensor(
|
|
tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype
|
|
)
|
|
# (R2T, R2T_M, R2T_N, (EPI_M, EPI_N))
|
|
tRT_rAcc_final = cute.group_modes(
|
|
tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_)
|
|
)
|
|
|
|
return (
|
|
tiled_copy_t2r,
|
|
tiled_copy_r2t,
|
|
tTR_tAcc,
|
|
tTR_rAcc,
|
|
tTR_rAcc_final,
|
|
tTR_sSFA,
|
|
tTR_sSFB,
|
|
tRT_rAcc_final,
|
|
tRT_tAcc_final,
|
|
)
|
|
|
|
def epilog_tmem_copy_and_partition(
|
|
self,
|
|
tidx: cutlass.Int32,
|
|
tAcc: cute.Tensor,
|
|
gC_mnl: cute.Tensor,
|
|
epi_tile: cute.Tile,
|
|
use_2cta_instrs: Union[cutlass.Boolean, bool],
|
|
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
"""
|
|
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
|
|
|
|
:param tidx: The thread index in epilogue warp groups
|
|
:type tidx: cutlass.Int32
|
|
:param tAcc: The accumulator tensor to be copied and partitioned
|
|
:type tAcc: cute.Tensor
|
|
:param gC_mnl: The global tensor C
|
|
:type gC_mnl: cute.Tensor
|
|
:param epi_tile: The epilogue tiler
|
|
:type epi_tile: cute.Tile
|
|
:param use_2cta_instrs: Whether use_2cta_instrs is enabled
|
|
:type use_2cta_instrs: bool
|
|
|
|
:return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
|
|
- tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
|
|
- tTR_tAcc: The partitioned accumulator tensor
|
|
- tTR_rAcc: The accumulated tensor in register used to hold t2r results
|
|
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
|
"""
|
|
# Make tiledCopy for tensor memory load
|
|
copy_atom_t2r = sm100_utils.get_tmem_load_op(
|
|
self.cta_tile_shape_mnk,
|
|
self.c_layout,
|
|
self.c_dtype,
|
|
self.acc_dtype,
|
|
epi_tile,
|
|
use_2cta_instrs,
|
|
)
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
|
|
tAcc_epi = cute.flat_divide(
|
|
tAcc[((None, None), 0, 0, None)],
|
|
epi_tile,
|
|
)
|
|
# (EPI_TILE_M, EPI_TILE_N)
|
|
tiled_copy_t2r = tcgen05.make_tmem_copy(
|
|
copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]
|
|
)
|
|
|
|
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
|
|
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
|
|
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
|
gC_mnl_epi = cute.flat_divide(
|
|
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
|
)
|
|
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
|
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
|
|
|
|
# (T2R, T2R_M, T2R_N)
|
|
tTR_rAcc = cute.make_rmem_tensor(
|
|
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype
|
|
)
|
|
|
|
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
|
|
|
|
def epilog_smem_copy_and_partition(
|
|
self,
|
|
tiled_copy_t2r: cute.TiledCopy,
|
|
tTR_rC: cute.Tensor,
|
|
tidx: cutlass.Int32,
|
|
sC: cute.Tensor,
|
|
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
|
"""
|
|
Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
|
|
|
|
:param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
|
|
:type tiled_copy_t2r: cute.TiledCopy
|
|
:param tTR_rC: The partitioned accumulator tensor
|
|
:type tTR_rC: cute.Tensor
|
|
:param tidx: The thread index in epilogue warp groups
|
|
:type tidx: cutlass.Int32
|
|
:param sC: The shared memory tensor to be copied and partitioned
|
|
:type sC: cute.Tensor
|
|
:type sepi: cute.Tensor
|
|
|
|
:return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where:
|
|
- tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s)
|
|
- tRS_rC: The partitioned tensor C (register source)
|
|
- tRS_sC: The partitioned tensor C (smem destination)
|
|
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
|
"""
|
|
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
|
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
|
)
|
|
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
|
# (R2S, R2S_M, R2S_N, PIPE_D)
|
|
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
|
tRS_sC = thr_copy_r2s.partition_D(sC)
|
|
# (R2S, R2S_M, R2S_N)
|
|
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
|
|
return tiled_copy_r2s, tRS_rC, tRS_sC
|
|
|
|
def epilog_gmem_copy_and_partition(
|
|
self,
|
|
tidx: cutlass.Int32,
|
|
atom: Union[cute.CopyAtom, cute.TiledCopy],
|
|
gC_mnl: cute.Tensor,
|
|
epi_tile: cute.Tile,
|
|
sC: cute.Tensor,
|
|
) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]:
|
|
"""Make tiledCopy for global memory store, then use it to:
|
|
- partition register array (source) and global memory (destination) for none TMA store version;
|
|
- partition shared memory (source) and global memory (destination) for TMA store version.
|
|
|
|
:param tidx: The thread index in epilogue warp groups
|
|
:type tidx: cutlass.Int32
|
|
:param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version
|
|
:type atom: cute.CopyAtom or cute.TiledCopy
|
|
:param gC_mnl: The global tensor C
|
|
:type gC_mnl: cute.Tensor
|
|
:param epi_tile: The epilogue tiler
|
|
:type epi_tile: cute.Tile
|
|
:param sC: The shared memory tensor to be copied and partitioned
|
|
:type sC: cute.Tensor
|
|
|
|
:return: A tuple containing :
|
|
- For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where:
|
|
- tma_atom_c: The TMA copy atom
|
|
- bSG_sC: The partitioned shared memory tensor C
|
|
- bSG_gC: The partitioned global tensor C
|
|
:rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
|
|
"""
|
|
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
|
|
gC_epi = cute.flat_divide(
|
|
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
|
)
|
|
tma_atom_c = atom
|
|
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
|
|
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
|
# ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL)
|
|
bSG_sC, bSG_gC = cpasync.tma_partition(
|
|
tma_atom_c,
|
|
0,
|
|
cute.make_layout(1),
|
|
sC_for_tma_partition,
|
|
gC_for_tma_partition,
|
|
)
|
|
return tma_atom_c, bSG_sC, bSG_gC
|
|
|
|
@staticmethod
|
|
def _compute_stages(
|
|
tiled_mma: cute.TiledMma,
|
|
mma_tiler_mnk: Tuple[int, int, int],
|
|
a_dtype: Type[cutlass.Numeric],
|
|
b_dtype: Type[cutlass.Numeric],
|
|
epi_tile: cute.Tile,
|
|
c_dtype: Type[cutlass.Numeric],
|
|
c_layout: utils.LayoutEnum,
|
|
sfa_dtype: Type[cutlass.Numeric],
|
|
sfb_dtype: Type[cutlass.Numeric],
|
|
sfa_count: int,
|
|
sfb_count: int,
|
|
num_smem_capacity: int,
|
|
occupancy: int,
|
|
) -> Tuple[int, int, int]:
|
|
"""Computes the number of stages for A/B/C operands based on heuristics.
|
|
|
|
:param tiled_mma: The tiled MMA object defining the core computation.
|
|
:type tiled_mma: cute.TiledMma
|
|
:param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
|
|
:type mma_tiler_mnk: tuple[int, int, int]
|
|
:param a_dtype: Data type of operand A.
|
|
:type a_dtype: type[cutlass.Numeric]
|
|
:param b_dtype: Data type of operand B.
|
|
:type b_dtype: type[cutlass.Numeric]
|
|
:param epi_tile: The epilogue tile shape.
|
|
:type epi_tile: cute.Tile
|
|
:param c_dtype: Data type of operand C (output).
|
|
:type c_dtype: type[cutlass.Numeric]
|
|
:param c_layout: Layout of operand C.
|
|
:type c_layout: utils.LayoutEnum
|
|
:param num_smem_capacity: Total available shared memory capacity in bytes.
|
|
:type num_smem_capacity: int
|
|
:param occupancy: Target number of CTAs per SM (occupancy).
|
|
:type occupancy: int
|
|
|
|
:return: A tuple containing the computed number of stages for:
|
|
(ACC stages, A/B operand stages, C stages)
|
|
:rtype: tuple[int, int, int]
|
|
"""
|
|
# Default ACC stages
|
|
num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6
|
|
|
|
# Default C stages
|
|
num_c_stage = 2
|
|
|
|
# Default ScaleA/B stages
|
|
num_scale_stage = 10
|
|
|
|
# Default Tile info stages
|
|
num_tile_stage = 2
|
|
|
|
# Calculate smem layout and size for one stage of A, B, and C
|
|
a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
|
|
tiled_mma,
|
|
mma_tiler_mnk,
|
|
a_dtype,
|
|
1, # a tmp 1 stage is provided
|
|
)
|
|
b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
|
|
tiled_mma,
|
|
mma_tiler_mnk,
|
|
b_dtype,
|
|
1, # a tmp 1 stage is provided
|
|
)
|
|
c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
|
|
c_dtype,
|
|
c_layout,
|
|
epi_tile,
|
|
1,
|
|
)
|
|
|
|
ab_bytes_per_stage = cute.size_in_bytes(
|
|
a_dtype, a_smem_layout_stage_one
|
|
) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
|
|
# 1024B alignment
|
|
mbar_helpers_bytes = 1024
|
|
c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
|
|
c_bytes = c_bytes_per_stage * num_c_stage
|
|
sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage
|
|
sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage
|
|
scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024
|
|
|
|
# Calculate A/B stages:
|
|
# Start with total smem per CTA (capacity / occupancy)
|
|
# Subtract reserved bytes and initial C stages bytes
|
|
# Divide remaining by bytes needed per A/B stage
|
|
num_ab_stage = (
|
|
num_smem_capacity // occupancy
|
|
- (mbar_helpers_bytes + c_bytes + scale_bytes)
|
|
) // ab_bytes_per_stage
|
|
|
|
# Refine epilogue stages:
|
|
# Calculate remaining smem after allocating for A/B stages and reserved bytes
|
|
# Add remaining unused smem to epilogue
|
|
num_c_stage += (
|
|
num_smem_capacity
|
|
- occupancy * ab_bytes_per_stage * num_ab_stage
|
|
- occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes)
|
|
) // (occupancy * c_bytes_per_stage)
|
|
return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage
|
|
|
|
@staticmethod
|
|
def _compute_grid(
|
|
c: cute.Tensor,
|
|
cta_tile_shape_mnk: Tuple[int, int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
max_active_clusters: cutlass.Constexpr,
|
|
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
|
|
"""Use persistent tile scheduler to compute the grid size for the output tensor C.
|
|
|
|
:param c: The output tensor C
|
|
:type c: cute.Tensor
|
|
:param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
|
|
:type cta_tile_shape_mnk: tuple[int, int, int]
|
|
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
|
|
:type cluster_shape_mn: tuple[int, int]
|
|
:param max_active_clusters: Maximum number of active clusters.
|
|
:type max_active_clusters: cutlass.Constexpr
|
|
|
|
:return: A tuple containing:
|
|
- tile_sched_params: Parameters for the persistent tile scheduler.
|
|
- grid: Grid shape for kernel launch.
|
|
:rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
|
|
"""
|
|
c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
|
|
gc = cute.zipped_divide(c, tiler=c_shape)
|
|
num_ctas_mnl = gc[(0, (None, None, None))].shape
|
|
cluster_shape_mnl = (*cluster_shape_mn, 1)
|
|
|
|
tile_sched_params = utils.PersistentTileSchedulerParams(
|
|
num_ctas_mnl, cluster_shape_mnl
|
|
)
|
|
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
|
|
tile_sched_params, max_active_clusters
|
|
)
|
|
|
|
return tile_sched_params, grid
|
|
|
|
@staticmethod
|
|
def _get_tma_atom_kind(
|
|
atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean
|
|
) -> Union[
|
|
cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp
|
|
]:
|
|
"""
|
|
Select the appropriate TMA copy atom based on the number of SMs and the multicast flag.
|
|
|
|
:param atom_sm_cnt: The number of SMs
|
|
:type atom_sm_cnt: cutlass.Int32
|
|
:param mcast: The multicast flag
|
|
:type mcast: cutlass.Boolean
|
|
|
|
:return: The appropriate TMA copy atom kind
|
|
:rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp
|
|
|
|
:raise ValueError: If the atom_sm_cnt is invalid
|
|
"""
|
|
if atom_sm_cnt == 2 and mcast:
|
|
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
|
|
elif atom_sm_cnt == 2 and not mcast:
|
|
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO)
|
|
elif atom_sm_cnt == 1 and mcast:
|
|
return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE)
|
|
elif atom_sm_cnt == 1 and not mcast:
|
|
return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
|
|
|
|
raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}")
|
|
|
|
@staticmethod
|
|
def is_valid_dtypes(
|
|
ab_dtype: Type[cutlass.Numeric],
|
|
acc_dtype: Type[cutlass.Numeric],
|
|
c_dtype: Type[cutlass.Numeric],
|
|
) -> bool:
|
|
"""
|
|
Check if the dtypes are valid
|
|
|
|
:param ab_dtype: The data type of the A and B operands
|
|
:type ab_dtype: Type[cutlass.Numeric]
|
|
:param acc_dtype: The data type of the accumulator
|
|
:type acc_dtype: Type[cutlass.Numeric]
|
|
:param c_dtype: The data type of the output tensor
|
|
:type c_dtype: Type[cutlass.Numeric]
|
|
|
|
:return: True if the dtypes are valid, False otherwise
|
|
:rtype: bool
|
|
"""
|
|
is_valid = True
|
|
if ab_dtype not in {
|
|
cutlass.Float8E4M3FN,
|
|
cutlass.Float8E5M2,
|
|
}:
|
|
is_valid = False
|
|
if acc_dtype not in {cutlass.Float32}:
|
|
is_valid = False
|
|
if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}:
|
|
is_valid = False
|
|
return is_valid
|
|
|
|
@staticmethod
|
|
def is_valid_mma_tiler_and_cluster_shape(
|
|
use_2cta_instrs: bool,
|
|
mma_tiler_mn: Tuple[int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
) -> bool:
|
|
"""
|
|
Check if the mma tiler and cluster shape are valid
|
|
|
|
:param use_2cta_instrs: Whether to use 2 CTA groups
|
|
:type use_2cta_instrs: bool
|
|
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
|
|
:type cluster_shape_mn: Tuple[int, int]
|
|
|
|
:return: True if the mma tiler and cluster shape are valid, False otherwise
|
|
:rtype: bool
|
|
"""
|
|
is_valid = True
|
|
# Skip invalid mma tile shape
|
|
if not (
|
|
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
|
|
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
|
|
):
|
|
is_valid = False
|
|
# Skip invalid mma tile n
|
|
if mma_tiler_mn[1] not in (128,):
|
|
is_valid = False
|
|
# Skip illegal cluster shape
|
|
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
|
|
is_valid = False
|
|
# Skip invalid cluster shape
|
|
is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
|
|
if (
|
|
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
|
|
or cluster_shape_mn[0] <= 0
|
|
or cluster_shape_mn[1] <= 0
|
|
or not is_power_of_2(cluster_shape_mn[0])
|
|
or not is_power_of_2(cluster_shape_mn[1])
|
|
):
|
|
is_valid = False
|
|
return is_valid
|
|
|
|
@staticmethod
|
|
def is_valid_tensor_alignment(
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
l: int,
|
|
ab_dtype: Type[cutlass.Numeric],
|
|
c_dtype: Type[cutlass.Numeric],
|
|
a_major: str,
|
|
b_major: str,
|
|
c_major: str,
|
|
) -> bool:
|
|
"""
|
|
Check if the tensor alignment is valid
|
|
|
|
:param m: The number of rows in the A tensor
|
|
:type m: int
|
|
:param n: The number of columns in the B tensor
|
|
:type n: int
|
|
:param k: The number of columns in the A tensor
|
|
:type k: int
|
|
:param l: The number of columns in the C tensor
|
|
:type l: int
|
|
:param ab_dtype: The data type of the A and B operands
|
|
:type ab_dtype: Type[cutlass.Numeric]
|
|
:param c_dtype: The data type of the output tensor
|
|
:type c_dtype: Type[cutlass.Numeric]
|
|
:param a_major: The major axis of the A tensor
|
|
:type a_major: str
|
|
:param b_major: The major axis of the B tensor
|
|
:type b_major: str
|
|
:param c_major: The major axis of the C tensor
|
|
:type c_major: str
|
|
|
|
:return: True if the problem shape is valid, False otherwise
|
|
:rtype: bool
|
|
"""
|
|
is_valid = True
|
|
|
|
def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
|
|
major_mode_idx = 0 if is_mode0_major else 1
|
|
num_major_elements = tensor_shape[major_mode_idx]
|
|
num_contiguous_elements = 16 * 8 // dtype.width
|
|
return num_major_elements % num_contiguous_elements == 0
|
|
|
|
if (
|
|
not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l))
|
|
or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l))
|
|
or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l))
|
|
):
|
|
is_valid = False
|
|
return is_valid
|
|
|
|
@staticmethod
|
|
def can_implement(
|
|
ab_dtype: Type[cutlass.Numeric],
|
|
acc_dtype: Type[cutlass.Numeric],
|
|
c_dtype: Type[cutlass.Numeric],
|
|
use_2cta_instrs: bool,
|
|
mma_tiler_mn: Tuple[int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
l: int,
|
|
a_major: str,
|
|
b_major: str,
|
|
c_major: str,
|
|
) -> bool:
|
|
"""
|
|
Check if the gemm can be implemented
|
|
|
|
:param ab_dtype: The data type of the A and B operands
|
|
:type ab_dtype: Type[cutlass.Numeric]
|
|
:param acc_dtype: The data type of the accumulator
|
|
:type acc_dtype: Type[cutlass.Numeric]
|
|
:param c_dtype: The data type of the output tensor
|
|
:type c_dtype: Type[cutlass.Numeric]
|
|
:param use_2cta_instrs: Whether to use 2 CTA groups
|
|
:type use_2cta_instrs: bool
|
|
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
|
|
:type mma_tiler_mn: Tuple[int, int]
|
|
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
|
|
:type cluster_shape_mn: Tuple[int, int]
|
|
:param m: The number of rows in the A tensor
|
|
:type m: int
|
|
:param n: The number of columns in the B tensor
|
|
:type n: int
|
|
:param k: The number of columns in the A tensor
|
|
:type k: int
|
|
:param l: The number of columns in the C tensor
|
|
:type l: int
|
|
:param a_major: The major axis of the A tensor
|
|
:type a_major: str
|
|
:param b_major: The major axis of the B tensor
|
|
:type b_major: str
|
|
:param c_major: The major axis of the C tensor
|
|
:type c_major: str
|
|
|
|
:return: True if the gemm can be implemented, False otherwise
|
|
:rtype: bool
|
|
"""
|
|
can_implement = True
|
|
# Skip unsupported types
|
|
if not BlockwiseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype):
|
|
can_implement = False
|
|
# Skip invalid mma tile shape and cluster shape
|
|
if not BlockwiseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
|
|
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
|
|
):
|
|
can_implement = False
|
|
# Skip illegal problem shape for load/store alignment
|
|
if not BlockwiseGemmKernel.is_valid_tensor_alignment(
|
|
m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
|
|
):
|
|
can_implement = False
|
|
# Skip unsupported A/B layout
|
|
if not (a_major == "k" and b_major == "k"):
|
|
can_implement = False
|
|
return can_implement
|
|
|
|
|
|
def create_tensors(
|
|
l, m, n, k, a_major, b_major, cd_major, ab_dtype, c_dtype, scale_dtype
|
|
):
|
|
torch.manual_seed(1111)
|
|
|
|
a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype)
|
|
b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype)
|
|
c_torch_cpu = cutlass_torch.matrix(l, m, n, cd_major == "m", c_dtype)
|
|
sfa_torch_cpu = cutlass_torch.matrix(l, m, math.ceil(k / 128), True, scale_dtype)
|
|
sfb_torch_cpu = cutlass_torch.matrix(
|
|
l, math.ceil(n / 128), math.ceil(k / 128), False, scale_dtype
|
|
)
|
|
|
|
a_tensor, _ = cutlass_torch.cute_tensor_like(
|
|
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
|
)
|
|
b_tensor, _ = cutlass_torch.cute_tensor_like(
|
|
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
|
|
)
|
|
c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like(
|
|
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
|
|
)
|
|
sfa_tensor, _ = cutlass_torch.cute_tensor_like(
|
|
sfa_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16
|
|
)
|
|
sfb_tensor, _ = cutlass_torch.cute_tensor_like(
|
|
sfb_torch_cpu, scale_dtype, is_dynamic_layout=True, assumed_align=16
|
|
)
|
|
|
|
return (
|
|
a_tensor,
|
|
b_tensor,
|
|
c_tensor,
|
|
sfa_tensor,
|
|
sfb_tensor,
|
|
a_torch_cpu,
|
|
b_torch_cpu,
|
|
c_torch_cpu,
|
|
sfa_torch_cpu,
|
|
sfb_torch_cpu,
|
|
c_torch_gpu,
|
|
)
|
|
|
|
|
|
def run(
|
|
mnkl: Tuple[int, int, int, int],
|
|
ab_dtype: Type[cutlass.Numeric],
|
|
c_dtype: Type[cutlass.Numeric],
|
|
acc_dtype: Type[cutlass.Numeric],
|
|
scale_dtype: Type[cutlass.Numeric],
|
|
a_major: str,
|
|
b_major: str,
|
|
c_major: str,
|
|
mma_tiler_mn: Tuple[int, int],
|
|
cluster_shape_mn: Tuple[int, int],
|
|
use_2cta_instrs: bool,
|
|
tolerance: float,
|
|
warmup_iterations: int = 0,
|
|
iterations: int = 1,
|
|
skip_ref_check: bool = False,
|
|
use_cold_l2: bool = False,
|
|
**kwargs,
|
|
):
|
|
"""
|
|
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
|
|
"""
|
|
print("Running Blackwell Persistent Dense Blockwise GEMM test with:")
|
|
print(f"mnkl: {mnkl}")
|
|
print(
|
|
f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}, Scale dtype: {scale_dtype}"
|
|
)
|
|
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
|
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
|
|
print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}")
|
|
print(f"Use TMA Store: {'True'}")
|
|
print(f"Tolerance: {tolerance}")
|
|
print(f"Warmup iterations: {warmup_iterations}")
|
|
print(f"Iterations: {iterations}")
|
|
print(f"Skip reference checking: {skip_ref_check}")
|
|
|
|
# Unpack parameters
|
|
m, n, k, l = mnkl
|
|
|
|
if not torch.cuda.is_available():
|
|
raise RuntimeError("GPU is required to run this example!")
|
|
|
|
if not BlockwiseGemmKernel.can_implement(
|
|
ab_dtype,
|
|
acc_dtype,
|
|
c_dtype,
|
|
use_2cta_instrs,
|
|
mma_tiler_mn,
|
|
cluster_shape_mn,
|
|
m,
|
|
n,
|
|
k,
|
|
l,
|
|
a_major,
|
|
b_major,
|
|
c_major,
|
|
):
|
|
raise TypeError(
|
|
f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
|
|
)
|
|
|
|
(
|
|
a_tensor,
|
|
b_tensor,
|
|
c_tensor,
|
|
sfa_tensor,
|
|
sfb_tensor,
|
|
a_torch_cpu,
|
|
b_torch_cpu,
|
|
c_torch_cpu,
|
|
sfa_torch_cpu,
|
|
sfb_torch_cpu,
|
|
c_torch_gpu,
|
|
) = create_tensors(
|
|
l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype, scale_dtype
|
|
)
|
|
# Configure gemm kernel
|
|
gemm = BlockwiseGemmKernel(
|
|
acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
|
|
)
|
|
|
|
# Compute max active clusters on current device
|
|
hardware_info = cutlass.utils.HardwareInfo()
|
|
max_active_clusters = hardware_info.get_max_active_clusters(
|
|
cluster_shape_mn[0] * cluster_shape_mn[1]
|
|
)
|
|
|
|
# Get current CUDA stream from PyTorch
|
|
torch_stream = torch.cuda.current_stream()
|
|
# Get the raw stream pointer as a CUstream
|
|
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
|
# Compile gemm kernel
|
|
compiled_gemm = cute.compile(
|
|
gemm,
|
|
a_tensor,
|
|
b_tensor,
|
|
c_tensor,
|
|
sfa_tensor,
|
|
sfb_tensor,
|
|
max_active_clusters,
|
|
current_stream,
|
|
)
|
|
|
|
# Execution
|
|
compiled_gemm(
|
|
a_tensor,
|
|
b_tensor,
|
|
c_tensor,
|
|
sfa_tensor,
|
|
sfb_tensor,
|
|
current_stream,
|
|
)
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
# Compute reference result
|
|
if not skip_ref_check:
|
|
# update
|
|
def pad_and_multiply(scale, tensor):
|
|
cm, ck, _ = scale.shape
|
|
m, k, _ = tensor.shape
|
|
IsGroupWise = False
|
|
IsBlockWise = False
|
|
if ck == math.ceil(k / 128):
|
|
IsGroupWise = True
|
|
if cm == math.ceil(m / 128):
|
|
IsBlockWise = True
|
|
if not IsBlockWise and not IsGroupWise:
|
|
raise ValueError("Only support granularity = 128")
|
|
|
|
k_idx = torch.arange(k, device=scale.device)
|
|
if IsGroupWise:
|
|
k_idx = k_idx // 128
|
|
m_idx = torch.arange(m, device=scale.device)
|
|
if IsBlockWise:
|
|
m_idx = m_idx // 128
|
|
expanded_scale = scale[m_idx[:, None], k_idx, :]
|
|
|
|
result = expanded_scale * tensor
|
|
|
|
return result
|
|
|
|
updated_a = pad_and_multiply(sfa_torch_cpu, a_torch_cpu)
|
|
updated_b = pad_and_multiply(sfb_torch_cpu, b_torch_cpu)
|
|
|
|
ref = torch.einsum("mkl,nkl->mnl", updated_a, updated_b).to(
|
|
cutlass_torch.dtype(c_dtype)
|
|
)
|
|
res = c_torch_gpu.view(cutlass_torch.dtype(c_dtype))
|
|
|
|
torch.testing.assert_close(res.cpu(), ref.cpu(), atol=tolerance, rtol=1e-03)
|
|
|
|
def generate_tensors():
|
|
# Reuse existing CPU reference tensors and create new GPU tensors from them
|
|
(
|
|
a_tensor,
|
|
b_tensor,
|
|
c_tensor,
|
|
sfa_tensor,
|
|
sfb_tensor,
|
|
a_torch_cpu,
|
|
b_torch_cpu,
|
|
c_torch_cpu,
|
|
sfa_torch_cpu,
|
|
sfb_torch_cpu,
|
|
c_torch_gpu,
|
|
) = create_tensors(
|
|
l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype, scale_dtype
|
|
)
|
|
return testing.JitArguments(
|
|
a_tensor,
|
|
b_tensor,
|
|
c_tensor,
|
|
sfa_tensor,
|
|
sfb_tensor,
|
|
current_stream,
|
|
)
|
|
|
|
workspace_count = 1
|
|
if use_cold_l2:
|
|
one_workspace_bytes = (
|
|
a_torch_cpu.numel() * a_torch_cpu.element_size()
|
|
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
|
|
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
|
|
+ sfa_torch_cpu.numel() * sfa_torch_cpu.element_size()
|
|
+ sfb_torch_cpu.numel() * sfb_torch_cpu.element_size()
|
|
)
|
|
workspace_count = testing.get_workspace_count(
|
|
one_workspace_bytes, warmup_iterations, iterations
|
|
)
|
|
|
|
exec_time = testing.benchmark(
|
|
compiled_gemm,
|
|
workspace_generator=generate_tensors,
|
|
workspace_count=workspace_count,
|
|
stream=current_stream,
|
|
warmup_iterations=warmup_iterations,
|
|
iterations=iterations,
|
|
)
|
|
|
|
return exec_time # Return execution time in microseconds
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
|
try:
|
|
return tuple(int(x.strip()) for x in s.split(","))
|
|
except ValueError:
|
|
raise argparse.ArgumentTypeError(
|
|
"Invalid format. Expected comma-separated integers."
|
|
)
|
|
|
|
parser = argparse.ArgumentParser(
|
|
description="Example of Dense Persistent GEMM on Blackwell."
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--mnkl",
|
|
type=parse_comma_separated_ints,
|
|
default=(256, 256, 512, 1),
|
|
help="mnkl dimensions (comma-separated)",
|
|
)
|
|
parser.add_argument(
|
|
"--mma_tiler_mn",
|
|
type=parse_comma_separated_ints,
|
|
default=(128, 128),
|
|
help="Mma tile shape (comma-separated)",
|
|
)
|
|
parser.add_argument(
|
|
"--cluster_shape_mn",
|
|
type=parse_comma_separated_ints,
|
|
default=(1, 1),
|
|
help="Cluster shape (comma-separated)",
|
|
)
|
|
parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN)
|
|
parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
|
|
parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
|
|
parser.add_argument("--scale_dtype", type=cutlass.dtype, default=cutlass.Float32)
|
|
parser.add_argument(
|
|
"--use_2cta_instrs",
|
|
action="store_true",
|
|
help="Enable 2CTA MMA instructions feature",
|
|
)
|
|
parser.add_argument("--a_major", choices=["k"], type=str, default="k")
|
|
parser.add_argument("--b_major", choices=["k"], type=str, default="k")
|
|
parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
|
|
parser.add_argument(
|
|
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
|
|
)
|
|
parser.add_argument(
|
|
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
|
|
)
|
|
parser.add_argument(
|
|
"--iterations",
|
|
type=int,
|
|
default=1,
|
|
help="Number of iterations to run the kernel",
|
|
)
|
|
parser.add_argument(
|
|
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
|
)
|
|
parser.add_argument(
|
|
"--use_cold_l2", action="store_true", default=False, help="Use cold L2"
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
if len(args.mnkl) != 4:
|
|
parser.error("--mnkl must contain exactly 4 values")
|
|
|
|
if len(args.mma_tiler_mn) != 2:
|
|
parser.error("--mma_tiler_mn must contain exactly 2 values")
|
|
|
|
if len(args.cluster_shape_mn) != 2:
|
|
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
|
|
|
run(
|
|
args.mnkl,
|
|
args.ab_dtype,
|
|
args.c_dtype,
|
|
args.acc_dtype,
|
|
args.scale_dtype,
|
|
args.a_major,
|
|
args.b_major,
|
|
args.c_major,
|
|
args.mma_tiler_mn,
|
|
args.cluster_shape_mn,
|
|
args.use_2cta_instrs,
|
|
args.tolerance,
|
|
args.warmup_iterations,
|
|
args.iterations,
|
|
args.skip_ref_check,
|
|
args.use_cold_l2,
|
|
)
|
|
print("PASS")
|