v4.2 tag release. (#2638)

Junkai-Wu
2025-09-16 00:21:53 +08:00
committed by GitHub
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions


@@ -44,7 +44,7 @@ import cutlass.utils.hopper_helpers as sm90_utils
"""
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
using CUTE DSL.
using CuTe DSL.
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
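The major conventions above map directly onto strides. Below is a minimal sketch of operands with these layouts, assuming PyTorch and small illustrative sizes (not part of the diff):

.. code-block:: python

    import torch

    # "K-major" means K is the contiguous (stride-1) mode, per the list above.
    m, n, k, l = 256, 256, 128, 2
    a = torch.empty(l, m, k, dtype=torch.float16).permute(1, 2, 0)  # (M, K, L), K-major
    b = torch.empty(l, n, k, dtype=torch.float16).permute(1, 2, 0)  # (N, K, L), K-major
    c = torch.empty(l, m, n, dtype=torch.float16).permute(1, 2, 0)  # (M, N, L), N-major
    assert a.stride(1) == b.stride(1) == c.stride(1) == 1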
@@ -70,7 +70,7 @@ To run this example:
.. code-block:: bash
python examples/hopper/dense_gemm.py \
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
--c_dtype Float16 --acc_dtype Float32 \
--a_major k --b_major k --c_major n
@@ -85,7 +85,7 @@ To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/hopper/dense_gemm.py \
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
--c_dtype Float16 --acc_dtype Float32 \
--a_major k --b_major k --c_major n
@@ -95,14 +95,11 @@ Constraints:
* For fp16 types, A and B must have the same data type
* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
* Fp8 types only support k-major layout
* Only fp32 accumulation is supported in this example
* CTA tile shape M must be 64/128
* CTA tile shape N must be 64/128/256
* CTA tile shape K must be 64
* Cluster shape M/N must be positive and power of 2, total cluster size <= 4
* The contiguous dimension of A/B/C tensors must be at least 16-byte aligned,
i.e., the number of elements is a multiple of 8 for Float16 and 16 for Float8.
* OOB tiles are not allowed when TMA store is disabled
"""
@@ -128,10 +125,10 @@ def parse_arguments() -> argparse.Namespace:
help="mnkl dimensions (comma-separated)",
)
parser.add_argument(
"--tile_shape_mnk",
"--tile_shape_mn",
type=parse_comma_separated_ints,
choices=[(128, 128, 64), (128, 256, 64), (128, 64, 64), (64, 64, 64)],
default=(128, 128, 64),
choices=[(128, 128), (128, 256), (128, 64), (64, 64)],
default=(128, 128),
help="Cta tile shape (comma-separated)",
)
parser.add_argument(
@@ -190,8 +187,8 @@ def parse_arguments() -> argparse.Namespace:
if len(args.mnkl) != 4:
parser.error("--mnkl must contain exactly 4 values")
if len(args.tile_shape_mnk) != 3:
parser.error("--tile_shape_mnk must contain exactly 3 values")
if len(args.tile_shape_mn) != 2:
parser.error("--tile_shape_mn must contain exactly 2 values")
if len(args.cluster_shape_mn) != 2:
parser.error("--cluster_shape_mn must contain exactly 2 values")
@@ -210,10 +207,10 @@ class HopperWgmmaGemmKernel:
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
:type cluster_shape_mnk: Tuple[int, int, int]
:param tile_shape_mn: Shape of the CTA tile (M,N)
:type tile_shape_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
:note: Data type requirements:
- For 16-bit types: A and B must have the same data type
@@ -236,8 +233,8 @@ class HopperWgmmaGemmKernel:
Example:
>>> gemm = HopperWgmmaGemmKernel(
... acc_dtype=cutlass.Float32,
... tile_shape_mnk=(128, 256, 64),
... cluster_shape_mnk=(1, 1, 1)
... tile_shape_mn=(128, 256),
... cluster_shape_mn=(1, 1)
... )
>>> gemm(a_tensor, b_tensor, c_tensor, stream)
"""
@@ -245,8 +242,8 @@ class HopperWgmmaGemmKernel:
def __init__(
self,
acc_dtype: type[cutlass.Numeric],
tile_shape_mnk: tuple[int, int, int],
cluster_shape_mnk: tuple[int, int, int],
tile_shape_mn: tuple[int, int],
cluster_shape_mn: tuple[int, int],
):
"""
Initializes the configuration for a Hopper dense GEMM kernel.
@@ -256,28 +253,30 @@ class HopperWgmmaGemmKernel:
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
:type cluster_shape_mnk: Tuple[int, int, int]
:param tile_shape_mn: Shape of the CTA tile (M,N)
:type tile_shape_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
"""
self.acc_dtype = acc_dtype
self.cluster_shape_mnk = cluster_shape_mnk
self.cluster_shape_mn = cluster_shape_mn
self.mma_inst_shape_mn = None
self.tile_shape_mnk = tuple(tile_shape_mnk)
# K dimension is deferred in _setup_attributes
self.tile_shape_mnk = (*tile_shape_mn, 1)
# For large tile size, using two warp groups is preferred because using only one warp
# group may result in register spill
self.atom_layout_mnk = (
(2, 1, 1)
if tile_shape_mnk[0] > 64 and tile_shape_mnk[1] > 128
if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128
else (1, 1, 1)
)
self.num_mcast_ctas_a = None
self.num_mcast_ctas_b = None
self.is_a_mcast = False
self.is_b_mcast = False
self.tiled_mma = None
self.occupancy = 1
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
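The warp-group heuristic above can be checked against the supported tile shapes; a sketch that mirrors the diff's condition:

.. code-block:: python

    # Only tiles exceeding (64, 128) in both M and N (i.e. 128x256 among the
    # supported choices) get two MMA warp groups, limiting register pressure.
    for tile_mn, expected in [((128, 256), (2, 1, 1)),
                              ((128, 128), (1, 1, 1)),
                              ((128, 64), (1, 1, 1)),
                              ((64, 64), (1, 1, 1))]:
        atom_layout_mnk = (2, 1, 1) if tile_mn[0] > 64 and tile_mn[1] > 128 else (1, 1, 1)
        assert atom_layout_mnk == expected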
@@ -315,12 +314,27 @@ class HopperWgmmaGemmKernel:
raise ValueError("CTA tile shape M must be 64/128")
if self.tile_shape_mnk[1] not in [64, 128, 256]:
raise ValueError("CTA tile shape N must be 64/128/256")
if self.tile_shape_mnk[2] not in [64]:
raise ValueError("CTA tile shape K must be 64")
self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
self.a_dtype,
self.b_dtype,
self.a_layout.sm90_mma_major_mode(),
self.b_layout.sm90_mma_major_mode(),
self.acc_dtype,
self.atom_layout_mnk,
tiler_mn=(64, self.tile_shape_mnk[1]),
)
mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.tile_shape_mnk = (
self.tile_shape_mnk[0],
self.tile_shape_mnk[1],
mma_inst_shape_k * mma_inst_tile_k,
)
self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1))
self.num_mcast_ctas_a = self.cluster_shape_mn[1]
self.num_mcast_ctas_b = self.cluster_shape_mn[0]
self.is_a_mcast = self.num_mcast_ctas_a > 1
self.is_b_mcast = self.num_mcast_ctas_b > 1
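With K no longer user-specified, the tile's K extent falls out of the MMA instruction shape. The arithmetic below assumes SM90 WGMMA instruction K extents of 16 for 16-bit and 32 for 8-bit operands:

.. code-block:: python

    # tile K = (WGMMA instruction K extent) * mma_inst_tile_k, as in the diff.
    mma_inst_tile_k = 4
    for name, mma_inst_shape_k in [("Float16", 16), ("Float8", 32)]:
        print(name, "tile K =", mma_inst_shape_k * mma_inst_tile_k)
    # Float16 tile K = 64 (matching the old hard-coded K constraint); Float8 -> 128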
@@ -401,28 +415,18 @@ class HopperWgmmaGemmKernel:
self._setup_attributes()
tiled_mma = sm90_utils.make_trivial_tiled_mma(
self.a_dtype,
self.b_dtype,
self.a_layout.sm90_mma_major_mode(),
self.b_layout.sm90_mma_major_mode(),
self.acc_dtype,
self.atom_layout_mnk,
tiler_mn=(64, self.tile_shape_mnk[1]),
)
tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
a,
self.a_smem_layout_staged,
(self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
self.cluster_shape_mnk[1],
self.cluster_shape_mn[1],
)
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
b,
self.b_smem_layout_staged,
(self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
self.cluster_shape_mnk[0],
self.cluster_shape_mn[0],
)
tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
@@ -431,20 +435,20 @@ class HopperWgmmaGemmKernel:
self.epi_tile,
)
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mnk)
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mn)
@cute.struct
class SharedStorage:
mainloop_pipeline_array_ptr: cute.struct.MemRange[
cutlass.Int64, self.ab_stage * 2
]
sa: cute.struct.Align[
sA: cute.struct.Align[
cute.struct.MemRange[
self.a_dtype, cute.cosize(self.a_smem_layout_staged)
],
self.buffer_align_bytes,
]
sb: cute.struct.Align[
sB: cute.struct.Align[
cute.struct.MemRange[
self.b_dtype, cute.cosize(self.b_smem_layout_staged)
],
@@ -461,7 +465,7 @@ class HopperWgmmaGemmKernel:
tma_tensor_b,
tma_atom_c,
tma_tensor_c,
tiled_mma,
self.tiled_mma,
self.cta_layout_mnk,
self.a_smem_layout_staged,
self.b_smem_layout_staged,
@@ -469,8 +473,7 @@ class HopperWgmmaGemmKernel:
).launch(
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=self.cluster_shape_mnk,
smem=self.shared_storage.size_in_bytes(),
cluster=(*self.cluster_shape_mn, 1),
stream=stream,
)
return
@@ -562,8 +565,8 @@ class HopperWgmmaGemmKernel:
# Get the pid from cluster id
bidx_in_cluster = cute.arch.block_in_cluster_idx()
pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
pid_m = cid_m * self.cluster_shape_mn[0] + bidx_in_cluster[0]
pid_n = cid_n * self.cluster_shape_mn[1] + bidx_in_cluster[1]
tile_coord_mnkl = (pid_m, pid_n, None, bidz)
cta_rank_in_cluster = cute.arch.make_warp_uniform(
@@ -621,22 +624,22 @@ class HopperWgmmaGemmKernel:
)
# Cluster arrive after barrier init
if cute.size(self.cluster_shape_mnk) > 1:
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_arrive_relaxed()
# ///////////////////////////////////////////////////////////////////////////////
# Generate smem tensor A/B
# ///////////////////////////////////////////////////////////////////////////////
sa = storage.sa.get_tensor(
sA = storage.sA.get_tensor(
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
)
sb = storage.sb.get_tensor(
sB = storage.sB.get_tensor(
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
)
sc_ptr = cute.recast_ptr(
sa.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
sC_ptr = cute.recast_ptr(
sA.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
)
sc = cute.make_tensor(sc_ptr, epi_smem_layout_staged.outer)
sC = cute.make_tensor(sC_ptr, epi_smem_layout_staged.outer)
# ///////////////////////////////////////////////////////////////////////////////
# Local_tile partition global tensors
@@ -673,34 +676,34 @@ class HopperWgmmaGemmKernel:
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
a_cta_crd = cluster_coord_mnk[1]
sa_for_tma_partition = cute.group_modes(sa, 0, 2)
sA_for_tma_partition = cute.group_modes(sA, 0, 2)
gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2)
tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
a_cta_crd,
a_cta_layout,
sa_for_tma_partition,
sA_for_tma_partition,
gA_for_tma_partition,
)
# TMA load B partition_S/D
b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
b_cta_crd = cluster_coord_mnk[0]
sb_for_tma_partition = cute.group_modes(sb, 0, 2)
sB_for_tma_partition = cute.group_modes(sB, 0, 2)
gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2)
tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
b_cta_crd,
b_cta_layout,
sb_for_tma_partition,
sB_for_tma_partition,
gB_for_tma_partition,
)
# //////////////////////////////////////////////////////////////////////////////
# Make frangments
# Make fragments
# //////////////////////////////////////////////////////////////////////////////
tCsA = thr_mma.partition_A(sa)
tCsB = thr_mma.partition_B(sb)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCrA = tiled_mma.make_fragment_A(tCsA)
tCrB = tiled_mma.make_fragment_B(tCsB)
@@ -711,7 +714,7 @@ class HopperWgmmaGemmKernel:
# Cluster wait
# ///////////////////////////////////////////////////////////////////////////////
# cluster wait for barrier init
if cute.size(self.cluster_shape_mnk) > 1:
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_wait()
else:
cute.arch.sync_threads()
@@ -788,7 +791,7 @@ class HopperWgmmaGemmKernel:
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_tile in range(k_pipe_mmas):
for k_tile in cutlass.range_constexpr(k_pipe_mmas):
# Wait for A/B buffer to be ready
mainloop_pipeline.consumer_wait(
mainloop_consumer_read_state, peek_ab_full_status
@@ -917,7 +920,7 @@ class HopperWgmmaGemmKernel:
# /////////////////////////////////////////////////////////////////////////////
cute.nvgpu.warpgroup.wait_group(0)
if cute.size(self.cluster_shape_mnk) > 1:
if cute.size(self.cluster_shape_mn) > 1:
# Wait for all threads in the cluster to finish, avoid early release of smem
cute.arch.cluster_arrive()
cute.arch.cluster_wait()
@@ -950,33 +953,45 @@ class HopperWgmmaGemmKernel:
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sD = thr_copy_r2s.partition_D(sc)
tRS_sD = thr_copy_r2s.partition_D(sC)
# (R2S, R2S_M, R2S_N)
tRS_rAcc = tiled_copy_r2s.retile(accumulators)
# Allocate D registers.
rD_shape = cute.shape(thr_copy_r2s.partition_S(sc))
rD_shape = cute.shape(thr_copy_r2s.partition_S(sC))
tRS_rD_layout = cute.make_layout(rD_shape[:3])
tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype)
size_tRS_rD = cute.size(tRS_rD)
sepi_for_tma_partition = cute.group_modes(sc, 0, 2)
tcgc_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
sepi_for_tma_partition = cute.group_modes(sC, 0, 2)
tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition(
tma_atom_c,
0,
cute.make_layout(1),
sepi_for_tma_partition,
tcgc_for_tma_partition,
tCgC_for_tma_partition,
)
epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
epi_tile_shape = tcgc_for_tma_partition.shape[1]
epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1])
epi_tile_shape = tCgC_for_tma_partition.shape[1]
epi_tile_layout = cute.make_layout(
epi_tile_shape, stride=(epi_tile_shape[1], 1)
)
for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num):
# Initialize tma store c_pipeline
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.epi_stage,
producer_group=c_producer_group,
)
for epi_idx in cutlass.range_constexpr(epi_tile_num):
# Copy from accumulators to D registers
for epi_v in range(size_tRS_rD):
for epi_v in cutlass.range_constexpr(size_tRS_rD):
tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
# Type conversion
@@ -997,10 +1012,6 @@ class HopperWgmmaGemmKernel:
# barrier for sync
cute.arch.barrier()
# Get the global memory coordinate for the current epi tile.
epi_tile_layout = cute.make_layout(
epi_tile_shape, stride=(epi_tile_shape[1], 1)
)
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
# Copy from shared memory to global memory
if warp_idx == 0:
@@ -1009,11 +1020,14 @@ class HopperWgmmaGemmKernel:
bSG_sD[(None, epi_buffer)],
bSG_gD[(None, gmem_coord)],
)
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
c_pipeline.producer_commit()
c_pipeline.producer_acquire()
cute.arch.barrier()
if warp_idx == 0:
c_pipeline.producer_tail()
return
@staticmethod
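The raw cp_async_bulk_commit_group/cp_async_bulk_wait_group intrinsics are replaced by the PipelineTmaStore helper; a comment-level sketch of the assumed correspondence, based on the calls the diff removes:

.. code-block:: python

    # Assumed mapping (the helper's internals may differ):
    #   c_pipeline.producer_commit()   # ~ cp_async_bulk_commit_group()
    #   c_pipeline.producer_acquire()  # ~ cp_async_bulk_wait_group(epi_stage - 1),
    #                                  #   so an smem epilogue buffer can be reused
    #   c_pipeline.producer_tail()     # ~ drain all outstanding store groups at exit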
@@ -1055,9 +1069,7 @@ class HopperWgmmaGemmKernel:
mbar_helpers_bytes = 1024
ab_stage = (
(smem_capacity - occupancy * 1024) // occupancy
- mbar_helpers_bytes
- epi_bytes
smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
) // ab_bytes_per_stage
return ab_stage, epi_stage
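A worked example of the simplified stage computation, with an assumed ~228 KB Hopper smem capacity and a hypothetical epilogue footprint:

.. code-block:: python

    smem_capacity = 228 * 1024                      # assumed H100 smem per SM
    occupancy = 1
    mbar_helpers_bytes = 1024
    ab_bytes_per_stage = (128 * 64 + 256 * 64) * 2  # 128x256x64 fp16 tile: 48 KB
    epi_bytes = 8192                                # hypothetical epilogue smem
    ab_stage = (smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes) // ab_bytes_per_stage
    print(ab_stage)  # 4 stages under these assumptions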
@@ -1195,7 +1207,7 @@ class HopperWgmmaGemmKernel:
def _compute_grid(
c: cute.Tensor,
tile_shape_mnk: tuple[int, int, int],
cluster_shape_mnk: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
) -> tuple[int, int, int]:
"""Compute grid shape for the output tensor C.
@@ -1203,8 +1215,8 @@ class HopperWgmmaGemmKernel:
:type c: cute.Tensor
:param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
:type tile_shape_mnk: tuple[int, int, int]
:param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
:type cluster_shape_mnk: tuple[int, int, int]
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
:type cluster_shape_mn: tuple[int, int]
:return: Grid shape for kernel launch.
:rtype: tuple[int, int, int]
@@ -1212,8 +1224,9 @@ class HopperWgmmaGemmKernel:
c_shape = (tile_shape_mnk[0], tile_shape_mnk[1])
gc = cute.zipped_divide(c, tiler=c_shape)
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk)
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk))
cluster_shape_mnl = (*cluster_shape_mn, 1)
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnl)
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnl))
return grid
@staticmethod
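A worked example of the grid computation for the command shown at the top (M=N=8192, 128x256 tile, 1x1 cluster):

.. code-block:: python

    import math

    # Rehash of _compute_grid's arithmetic with concrete numbers.
    m, n, l = 8192, 8192, 1
    tile_m, tile_n = 128, 256
    cluster_mnl = (1, 1, 1)
    tiles = (math.ceil(m / tile_m), math.ceil(n / tile_n), l)
    clusters = tuple(math.ceil(t / c) for t, c in zip(tiles, cluster_mnl))
    grid = tuple(x * y for x, y in zip(clusters, cluster_mnl))
    print(grid)  # (64, 32, 1)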
@@ -1363,7 +1376,7 @@ def run(
a_major: str,
b_major: str,
c_major: str,
tile_shape_mnk: Tuple[int, int, int],
tile_shape_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
tolerance: float,
warmup_iterations: int,
@@ -1387,8 +1400,8 @@ def run(
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param tile_shape_mnk: CTA tile shape (M, N, K)
:type tile_shape_mnk: Tuple[int, int, int]
:param tile_shape_mn: CTA tile shape (M, N)
:type tile_shape_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster shape (M, N)
:type cluster_shape_mn: Tuple[int, int]
:param tolerance: Tolerance value for reference validation comparison
@@ -1411,7 +1424,7 @@ def run(
f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
)
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
@@ -1420,7 +1433,6 @@ def run(
# Unpack parameters
m, n, k, l = mnkl
cluster_shape_mnk = (*cluster_shape_mn, 1)
# Skip unsupported types
if not HopperWgmmaGemmKernel.is_valid_dtypes(
@@ -1488,7 +1500,7 @@ def run(
b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mn, cluster_shape_mn)
torch_stream = torch.cuda.Stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
@@ -1572,7 +1584,7 @@ if __name__ == "__main__":
args.a_major,
args.b_major,
args.c_major,
args.tile_shape_mnk,
args.tile_shape_mn,
args.cluster_shape_mn,
args.tolerance,
args.warmup_iterations,