v4.2 tag release. (#2638)
This commit is contained in:
@ -44,7 +44,7 @@ import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
"""
|
||||
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
|
||||
using CUTE DSL.
|
||||
using CuTe DSL.
|
||||
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
||||
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
||||
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
||||
@ -70,7 +70,7 @@ To run this example:
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/hopper/dense_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
|
||||
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major k --b_major k --c_major n
|
||||
@ -85,7 +85,7 @@ To collect performance with NCU profiler:
|
||||
.. code-block:: bash
|
||||
|
||||
ncu python examples/hopper/dense_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
||||
--mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \
|
||||
--cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major k --b_major k --c_major n
|
||||
@ -95,14 +95,11 @@ Constraints:
|
||||
* For fp16 types, A and B must have the same data type
|
||||
* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit
|
||||
* Fp8 types only support k-major layout
|
||||
* Only fp32 accumulation is supported in this example
|
||||
* CTA tile shape M must be 64/128
|
||||
* CTA tile shape N must be 64/128/256
|
||||
* CTA tile shape K must be 64
|
||||
* Cluster shape M/N must be positive and power of 2, total cluster size <= 4
|
||||
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
||||
i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively.
|
||||
* OOB tiles are not allowed when TMA store is disabled
|
||||
"""
|
||||
|
||||
|
||||
@ -128,10 +125,10 @@ def parse_arguments() -> argparse.Namespace:
|
||||
help="mnkl dimensions (comma-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_shape_mnk",
|
||||
"--tile_shape_mn",
|
||||
type=parse_comma_separated_ints,
|
||||
choices=[(128, 128, 64), (128, 256, 64), (128, 64, 64), (64, 64, 64)],
|
||||
default=(128, 128, 64),
|
||||
choices=[(128, 128), (128, 256), (128, 64), (64, 64)],
|
||||
default=(128, 128),
|
||||
help="Cta tile shape (comma-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -190,8 +187,8 @@ def parse_arguments() -> argparse.Namespace:
|
||||
|
||||
if len(args.mnkl) != 4:
|
||||
parser.error("--mnkl must contain exactly 4 values")
|
||||
if len(args.tile_shape_mnk) != 3:
|
||||
parser.error("--tile_shape_mnk must contain exactly 3 values")
|
||||
if len(args.tile_shape_mn) != 2:
|
||||
parser.error("--tile_shape_mn must contain exactly 2 values")
|
||||
if len(args.cluster_shape_mn) != 2:
|
||||
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
||||
|
||||
@ -210,10 +207,10 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
:param acc_dtype: Data type for accumulation during computation
|
||||
:type acc_dtype: type[cutlass.Numeric]
|
||||
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
|
||||
:type tile_shape_mnk: Tuple[int, int, int]
|
||||
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
|
||||
:type cluster_shape_mnk: Tuple[int, int, int]
|
||||
:param tile_shape_mn: Shape of the CTA tile (M,N)
|
||||
:type tile_shape_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
|
||||
:note: Data type requirements:
|
||||
- For 16-bit types: A and B must have the same data type
|
||||
@ -236,8 +233,8 @@ class HopperWgmmaGemmKernel:
|
||||
Example:
|
||||
>>> gemm = HopperWgmmaGemmKernel(
|
||||
... acc_dtype=cutlass.Float32,
|
||||
... tile_shape_mnk=(128, 256, 64),
|
||||
... cluster_shape_mnk=(1, 1, 1)
|
||||
... tile_shape_mn=(128, 256),
|
||||
... cluster_shape_mn=(1, 1)
|
||||
... )
|
||||
>>> gemm(a_tensor, b_tensor, c_tensor, stream)
|
||||
"""
|
||||
@ -245,8 +242,8 @@ class HopperWgmmaGemmKernel:
|
||||
def __init__(
|
||||
self,
|
||||
acc_dtype: type[cutlass.Numeric],
|
||||
tile_shape_mnk: tuple[int, int, int],
|
||||
cluster_shape_mnk: tuple[int, int, int],
|
||||
tile_shape_mn: tuple[int, int],
|
||||
cluster_shape_mn: tuple[int, int],
|
||||
):
|
||||
"""
|
||||
Initializes the configuration for a Hopper dense GEMM kernel.
|
||||
@ -256,28 +253,30 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
:param acc_dtype: Data type for accumulation during computation
|
||||
:type acc_dtype: type[cutlass.Numeric]
|
||||
:param tile_shape_mnk: Shape of the CTA tile (M,N,K)
|
||||
:type tile_shape_mnk: Tuple[int, int, int]
|
||||
:param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing
|
||||
:type cluster_shape_mnk: Tuple[int, int, int]
|
||||
:param tile_shape_mn: Shape of the CTA tile (M,N)
|
||||
:type tile_shape_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
"""
|
||||
|
||||
self.acc_dtype = acc_dtype
|
||||
|
||||
self.cluster_shape_mnk = cluster_shape_mnk
|
||||
self.cluster_shape_mn = cluster_shape_mn
|
||||
self.mma_inst_shape_mn = None
|
||||
self.tile_shape_mnk = tuple(tile_shape_mnk)
|
||||
# K dimension is deferred in _setup_attributes
|
||||
self.tile_shape_mnk = (*tile_shape_mn, 1)
|
||||
# For large tile size, using two warp groups is preferred because using only one warp
|
||||
# group may result in register spill
|
||||
self.atom_layout_mnk = (
|
||||
(2, 1, 1)
|
||||
if tile_shape_mnk[0] > 64 and tile_shape_mnk[1] > 128
|
||||
if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128
|
||||
else (1, 1, 1)
|
||||
)
|
||||
self.num_mcast_ctas_a = None
|
||||
self.num_mcast_ctas_b = None
|
||||
self.is_a_mcast = False
|
||||
self.is_b_mcast = False
|
||||
self.tiled_mma = None
|
||||
|
||||
self.occupancy = 1
|
||||
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
|
||||
@ -315,12 +314,27 @@ class HopperWgmmaGemmKernel:
|
||||
raise ValueError("CTA tile shape M must be 64/128")
|
||||
if self.tile_shape_mnk[1] not in [64, 128, 256]:
|
||||
raise ValueError("CTA tile shape N must be 64/128/256")
|
||||
if self.tile_shape_mnk[2] not in [64]:
|
||||
raise ValueError("CTA tile shape K must be 64")
|
||||
|
||||
self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
|
||||
self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
|
||||
self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
|
||||
self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
||||
self.a_dtype,
|
||||
self.b_dtype,
|
||||
self.a_layout.sm90_mma_major_mode(),
|
||||
self.b_layout.sm90_mma_major_mode(),
|
||||
self.acc_dtype,
|
||||
self.atom_layout_mnk,
|
||||
tiler_mn=(64, self.tile_shape_mnk[1]),
|
||||
)
|
||||
mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
|
||||
mma_inst_tile_k = 4
|
||||
self.tile_shape_mnk = (
|
||||
self.tile_shape_mnk[0],
|
||||
self.tile_shape_mnk[1],
|
||||
mma_inst_shape_k * mma_inst_tile_k,
|
||||
)
|
||||
|
||||
self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1))
|
||||
self.num_mcast_ctas_a = self.cluster_shape_mn[1]
|
||||
self.num_mcast_ctas_b = self.cluster_shape_mn[0]
|
||||
self.is_a_mcast = self.num_mcast_ctas_a > 1
|
||||
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
||||
|
||||
@ -401,28 +415,18 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
self._setup_attributes()
|
||||
|
||||
tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
||||
self.a_dtype,
|
||||
self.b_dtype,
|
||||
self.a_layout.sm90_mma_major_mode(),
|
||||
self.b_layout.sm90_mma_major_mode(),
|
||||
self.acc_dtype,
|
||||
self.atom_layout_mnk,
|
||||
tiler_mn=(64, self.tile_shape_mnk[1]),
|
||||
)
|
||||
|
||||
tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
|
||||
a,
|
||||
self.a_smem_layout_staged,
|
||||
(self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
|
||||
self.cluster_shape_mnk[1],
|
||||
self.cluster_shape_mn[1],
|
||||
)
|
||||
|
||||
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
|
||||
b,
|
||||
self.b_smem_layout_staged,
|
||||
(self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
|
||||
self.cluster_shape_mnk[0],
|
||||
self.cluster_shape_mn[0],
|
||||
)
|
||||
|
||||
tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
|
||||
@ -431,20 +435,20 @@ class HopperWgmmaGemmKernel:
|
||||
self.epi_tile,
|
||||
)
|
||||
|
||||
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mnk)
|
||||
grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mn)
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
mainloop_pipeline_array_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.ab_stage * 2
|
||||
]
|
||||
sa: cute.struct.Align[
|
||||
sA: cute.struct.Align[
|
||||
cute.struct.MemRange[
|
||||
self.a_dtype, cute.cosize(self.a_smem_layout_staged)
|
||||
],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
sb: cute.struct.Align[
|
||||
sB: cute.struct.Align[
|
||||
cute.struct.MemRange[
|
||||
self.b_dtype, cute.cosize(self.b_smem_layout_staged)
|
||||
],
|
||||
@ -461,7 +465,7 @@ class HopperWgmmaGemmKernel:
|
||||
tma_tensor_b,
|
||||
tma_atom_c,
|
||||
tma_tensor_c,
|
||||
tiled_mma,
|
||||
self.tiled_mma,
|
||||
self.cta_layout_mnk,
|
||||
self.a_smem_layout_staged,
|
||||
self.b_smem_layout_staged,
|
||||
@ -469,8 +473,7 @@ class HopperWgmmaGemmKernel:
|
||||
).launch(
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=self.cluster_shape_mnk,
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@ -562,8 +565,8 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
# Get the pid from cluster id
|
||||
bidx_in_cluster = cute.arch.block_in_cluster_idx()
|
||||
pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0]
|
||||
pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1]
|
||||
pid_m = cid_m * self.cluster_shape_mn[0] + bidx_in_cluster[0]
|
||||
pid_n = cid_n * self.cluster_shape_mn[1] + bidx_in_cluster[1]
|
||||
|
||||
tile_coord_mnkl = (pid_m, pid_n, None, bidz)
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
@ -621,22 +624,22 @@ class HopperWgmmaGemmKernel:
|
||||
)
|
||||
|
||||
# Cluster arrive after barrier init
|
||||
if cute.size(self.cluster_shape_mnk) > 1:
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
cute.arch.cluster_arrive_relaxed()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Generate smem tensor A/B
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
sa = storage.sa.get_tensor(
|
||||
sA = storage.sA.get_tensor(
|
||||
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
|
||||
)
|
||||
sb = storage.sb.get_tensor(
|
||||
sB = storage.sB.get_tensor(
|
||||
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
|
||||
)
|
||||
sc_ptr = cute.recast_ptr(
|
||||
sa.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
|
||||
sC_ptr = cute.recast_ptr(
|
||||
sA.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype
|
||||
)
|
||||
sc = cute.make_tensor(sc_ptr, epi_smem_layout_staged.outer)
|
||||
sC = cute.make_tensor(sC_ptr, epi_smem_layout_staged.outer)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Local_tile partition global tensors
|
||||
@ -673,34 +676,34 @@ class HopperWgmmaGemmKernel:
|
||||
# TMA load A partition_S/D
|
||||
a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
|
||||
a_cta_crd = cluster_coord_mnk[1]
|
||||
sa_for_tma_partition = cute.group_modes(sa, 0, 2)
|
||||
sA_for_tma_partition = cute.group_modes(sA, 0, 2)
|
||||
gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2)
|
||||
tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
a_cta_crd,
|
||||
a_cta_layout,
|
||||
sa_for_tma_partition,
|
||||
sA_for_tma_partition,
|
||||
gA_for_tma_partition,
|
||||
)
|
||||
|
||||
# TMA load B partition_S/D
|
||||
b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
|
||||
b_cta_crd = cluster_coord_mnk[0]
|
||||
sb_for_tma_partition = cute.group_modes(sb, 0, 2)
|
||||
sB_for_tma_partition = cute.group_modes(sB, 0, 2)
|
||||
gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2)
|
||||
tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
b_cta_crd,
|
||||
b_cta_layout,
|
||||
sb_for_tma_partition,
|
||||
sB_for_tma_partition,
|
||||
gB_for_tma_partition,
|
||||
)
|
||||
|
||||
# //////////////////////////////////////////////////////////////////////////////
|
||||
# Make frangments
|
||||
# Make fragments
|
||||
# //////////////////////////////////////////////////////////////////////////////
|
||||
tCsA = thr_mma.partition_A(sa)
|
||||
tCsB = thr_mma.partition_B(sb)
|
||||
tCsA = thr_mma.partition_A(sA)
|
||||
tCsB = thr_mma.partition_B(sB)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA)
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB)
|
||||
|
||||
@ -711,7 +714,7 @@ class HopperWgmmaGemmKernel:
|
||||
# Cluster wait
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# cluster wait for barrier init
|
||||
if cute.size(self.cluster_shape_mnk) > 1:
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
cute.arch.cluster_wait()
|
||||
else:
|
||||
cute.arch.sync_threads()
|
||||
@ -788,7 +791,7 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_tile in range(k_pipe_mmas):
|
||||
for k_tile in cutlass.range_constexpr(k_pipe_mmas):
|
||||
# Wait for A/B buffer to be ready
|
||||
mainloop_pipeline.consumer_wait(
|
||||
mainloop_consumer_read_state, peek_ab_full_status
|
||||
@ -917,7 +920,7 @@ class HopperWgmmaGemmKernel:
|
||||
# /////////////////////////////////////////////////////////////////////////////
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
if cute.size(self.cluster_shape_mnk) > 1:
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
# Wait for all threads in the cluster to finish, avoid early release of smem
|
||||
cute.arch.cluster_arrive()
|
||||
cute.arch.cluster_wait()
|
||||
@ -950,33 +953,45 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sD = thr_copy_r2s.partition_D(sc)
|
||||
tRS_sD = thr_copy_r2s.partition_D(sC)
|
||||
# (R2S, R2S_M, R2S_N)
|
||||
tRS_rAcc = tiled_copy_r2s.retile(accumulators)
|
||||
|
||||
# Allocate D registers.
|
||||
rD_shape = cute.shape(thr_copy_r2s.partition_S(sc))
|
||||
rD_shape = cute.shape(thr_copy_r2s.partition_S(sC))
|
||||
tRS_rD_layout = cute.make_layout(rD_shape[:3])
|
||||
tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype)
|
||||
size_tRS_rD = cute.size(tRS_rD)
|
||||
|
||||
sepi_for_tma_partition = cute.group_modes(sc, 0, 2)
|
||||
tcgc_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
|
||||
sepi_for_tma_partition = cute.group_modes(sC, 0, 2)
|
||||
tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile)
|
||||
|
||||
bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
sepi_for_tma_partition,
|
||||
tcgc_for_tma_partition,
|
||||
tCgC_for_tma_partition,
|
||||
)
|
||||
|
||||
epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
|
||||
epi_tile_shape = tcgc_for_tma_partition.shape[1]
|
||||
epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1])
|
||||
epi_tile_shape = tCgC_for_tma_partition.shape[1]
|
||||
epi_tile_layout = cute.make_layout(
|
||||
epi_tile_shape, stride=(epi_tile_shape[1], 1)
|
||||
)
|
||||
|
||||
for epi_idx in cutlass.range(epi_tile_num, unroll=epi_tile_num):
|
||||
# Initialize tma store c_pipeline
|
||||
c_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta
|
||||
)
|
||||
c_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.epi_stage,
|
||||
producer_group=c_producer_group,
|
||||
)
|
||||
|
||||
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
||||
# Copy from accumulators to D registers
|
||||
for epi_v in range(size_tRS_rD):
|
||||
for epi_v in cutlass.range_constexpr(size_tRS_rD):
|
||||
tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
|
||||
|
||||
# Type conversion
|
||||
@ -997,10 +1012,6 @@ class HopperWgmmaGemmKernel:
|
||||
# barrier for sync
|
||||
cute.arch.barrier()
|
||||
|
||||
# Get the global memory coordinate for the current epi tile.
|
||||
epi_tile_layout = cute.make_layout(
|
||||
epi_tile_shape, stride=(epi_tile_shape[1], 1)
|
||||
)
|
||||
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
||||
# Copy from shared memory to global memory
|
||||
if warp_idx == 0:
|
||||
@ -1009,11 +1020,14 @@ class HopperWgmmaGemmKernel:
|
||||
bSG_sD[(None, epi_buffer)],
|
||||
bSG_gD[(None, gmem_coord)],
|
||||
)
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
|
||||
c_pipeline.producer_commit()
|
||||
c_pipeline.producer_acquire()
|
||||
|
||||
cute.arch.barrier()
|
||||
|
||||
if warp_idx == 0:
|
||||
c_pipeline.producer_tail()
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@ -1055,9 +1069,7 @@ class HopperWgmmaGemmKernel:
|
||||
mbar_helpers_bytes = 1024
|
||||
|
||||
ab_stage = (
|
||||
(smem_capacity - occupancy * 1024) // occupancy
|
||||
- mbar_helpers_bytes
|
||||
- epi_bytes
|
||||
smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes
|
||||
) // ab_bytes_per_stage
|
||||
return ab_stage, epi_stage
|
||||
|
||||
@ -1195,7 +1207,7 @@ class HopperWgmmaGemmKernel:
|
||||
def _compute_grid(
|
||||
c: cute.Tensor,
|
||||
tile_shape_mnk: tuple[int, int, int],
|
||||
cluster_shape_mnk: tuple[int, int, int],
|
||||
cluster_shape_mn: tuple[int, int],
|
||||
) -> tuple[int, int, int]:
|
||||
"""Compute grid shape for the output tensor C.
|
||||
|
||||
@ -1203,8 +1215,8 @@ class HopperWgmmaGemmKernel:
|
||||
:type c: cute.Tensor
|
||||
:param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
|
||||
:type tile_shape_mnk: tuple[int, int, int]
|
||||
:param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions.
|
||||
:type cluster_shape_mnk: tuple[int, int, int]
|
||||
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
|
||||
:type cluster_shape_mn: tuple[int, int]
|
||||
|
||||
:return: Grid shape for kernel launch.
|
||||
:rtype: tuple[int, int, int]
|
||||
@ -1212,8 +1224,9 @@ class HopperWgmmaGemmKernel:
|
||||
|
||||
c_shape = (tile_shape_mnk[0], tile_shape_mnk[1])
|
||||
gc = cute.zipped_divide(c, tiler=c_shape)
|
||||
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk)
|
||||
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk))
|
||||
cluster_shape_mnl = (*cluster_shape_mn, 1)
|
||||
clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnl)
|
||||
grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnl))
|
||||
return grid
|
||||
|
||||
@staticmethod
|
||||
@ -1363,7 +1376,7 @@ def run(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
tile_shape_mnk: Tuple[int, int, int],
|
||||
tile_shape_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
tolerance: float,
|
||||
warmup_iterations: int,
|
||||
@ -1387,8 +1400,8 @@ def run(
|
||||
:type acc_dtype: Type[cutlass.Numeric]
|
||||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||||
:type a_major/b_major/c_major: str
|
||||
:param tile_shape_mnk: CTA tile shape (M, N, K)
|
||||
:type tile_shape_mnk: Tuple[int, int, int]
|
||||
:param tile_shape_mn: CTA tile shape (M, N)
|
||||
:type tile_shape_mn: Tuple[int, int]
|
||||
:param cluster_shape_mn: Cluster shape (M, N)
|
||||
:type cluster_shape_mn: Tuple[int, int]
|
||||
:param tolerance: Tolerance value for reference validation comparison
|
||||
@ -1411,7 +1424,7 @@ def run(
|
||||
f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
|
||||
)
|
||||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||||
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
|
||||
print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}")
|
||||
print(f"Tolerance: {tolerance}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Iterations: {iterations}")
|
||||
@ -1420,7 +1433,6 @@ def run(
|
||||
|
||||
# Unpack parameters
|
||||
m, n, k, l = mnkl
|
||||
cluster_shape_mnk = (*cluster_shape_mn, 1)
|
||||
|
||||
# Skip unsupported types
|
||||
if not HopperWgmmaGemmKernel.is_valid_dtypes(
|
||||
@ -1488,7 +1500,7 @@ def run(
|
||||
b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
|
||||
c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
|
||||
|
||||
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk)
|
||||
gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mn, cluster_shape_mn)
|
||||
|
||||
torch_stream = torch.cuda.Stream()
|
||||
stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
@ -1572,7 +1584,7 @@ if __name__ == "__main__":
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.tile_shape_mnk,
|
||||
args.tile_shape_mn,
|
||||
args.cluster_shape_mn,
|
||||
args.tolerance,
|
||||
args.warmup_iterations,
|
||||
|
||||
Reference in New Issue
Block a user