v4.1 release update v2. (#2481)

Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -90,18 +90,17 @@ If you already know the TV layout you want to use for your tiled copy, CuTe DSL
# Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN))
gA = cute.zipped_divide(mA, tiler_mn)
where `tiler_mn` is the tile size per thread block and `tv_layout` is the TV layout which maps
thread index and inter-thread index of data array per thread to logical coordinates of elements in
input and output tensors.
Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy` utility.
Then we can build tiled copies for the input and output tensors with the `cute.make_tiled_copy_tv` utility,
which infers the tiler and TV layout for the tiled copy automatically. Here the tiler is the tile size per
thread block and the TV layout maps the thread index and inter-thread index of each thread's data array to
logical coordinates of elements in the input and output tensors.
.. code-block:: python
blkA = gA[((None, None), bidx)] # (TileM,TileN)
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
# get slice of tiled_copy_A for current thread
thr_copy_A = tiled_copy_A.get_slice(tidx)
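For orientation, here is a minimal sketch of thread and value layouts that could be passed to
`cute.make_tiled_copy_tv`; the shapes and strides below are illustrative assumptions, not values taken
from this example.

.. code-block:: python

    # Assumed: 128 threads arranged as 32 x 4, each thread copying a 1 x 8 vector of elements
    thr_layout = cute.make_layout((32, 4), stride=(4, 1))
    val_layout = cute.make_layout((1, 8), stride=(8, 1))

    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)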
@ -140,8 +139,8 @@ def elementwise_add_kernel(
gC: cute.Tensor,
cC: cute.Tensor, # coordinate tensor
shape: cute.Shape,
tv_layout: cute.Layout,
tiler_mn: cute.Shape,
thr_layout: cute.Layout,
val_layout: cute.Layout,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
@ -165,9 +164,9 @@ def elementwise_add_kernel(
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
tiled_copy_B = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
tiled_copy_C = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)
thr_copy_A = tiled_copy_A.get_slice(tidx)
thr_copy_B = tiled_copy_B.get_slice(tidx)
@ -254,7 +253,7 @@ def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128):
cC = cute.zipped_divide(idC, tiler=tiler_mn)
print(f"[DSL INFO] coord tensor = {cC.type}")
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, tv_layout, tiler_mn).launch(
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(
grid=[cute.size(gC, mode=[1]), 1, 1],
block=[cute.size(tv_layout, mode=[0]), 1, 1],
)
@ -362,7 +361,7 @@ def run_elementwise_add(
workspace_generator=generate_tensors,
workspace_count=10,
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
iterations=iterations,
)
# Print execution results

View File

@ -353,7 +353,7 @@ def run_elementwise_apply_and_verify(
current_stream,
),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
iterations=iterations,
use_cuda_graphs=True,
stream=current_stream,
)

View File

@ -32,13 +32,13 @@ from typing import Type, Union, Callable
import torch
import cuda.bindings.driver as cuda
import cutlass.cute.testing as testing
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, warp
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack
import cutlass.utils.ampere_helpers as sm80_utils
import cutlass.utils as utils
"""
A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL.
@ -163,7 +163,7 @@ class FlashAttentionForwardAmpere:
# Check if block size setting is out of shared memory capacity
# Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size
smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2
smem_capacity = sm80_utils.SMEM_CAPACITY["sm80"]
smem_capacity = utils.get_smem_capacity_in_bytes("sm_80")
if smem_usage > smem_capacity:
return False
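# Illustrative check (assumed values, not this kernel's defaults): with m_block_size=128,
# n_block_size=64, head_dim=64 and 2-byte elements, smem_usage = (128*64 + 64*64*2) * 2
# = 32768 bytes, which must not exceed the SM80 shared-memory capacity queried above.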
@ -469,21 +469,9 @@ class FlashAttentionForwardAmpere:
warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),
self._dtype,
)
smem_tiled_copy_Q = cute.make_tiled_copy(
smem_copy_atom_Q,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
smem_tiled_copy_K = cute.make_tiled_copy(
smem_copy_atom_K,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
smem_tiled_copy_V = cute.make_tiled_copy(
smem_copy_atom_V,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma)
smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma)
smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma)
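# Note: make_tiled_copy_A/_B derive the TV layout and tile shape from the tiled MMA's
# A/B operand partitioning, replacing the explicit layout_tv / tiler_mn arguments used previously.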
smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx)
smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx)
@ -702,11 +690,7 @@ class FlashAttentionForwardAmpere:
cute.nvgpu.CopyUniversalOp(), self._dtype
)
# tiled copy atom for O
smem_tiled_copy_O = cute.make_tiled_copy(
smem_copy_atom_O,
layout_tv=tiled_mma.tv_layout_C_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)),
)
smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma)
smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx)
taccOrO = smem_thr_copy_O.retile(rO)
taccOsO = smem_thr_copy_O.partition_D(sO)
@ -1178,7 +1162,7 @@ class FlashAttentionForwardAmpere:
return cute.arch.exp2(x)
def run_flash_attention_fwd(
def run(
dtype: Type[cutlass.Numeric],
batch_size: int,
seqlen_q: int,
@ -1193,6 +1177,8 @@ def run_flash_attention_fwd(
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
# Skip unsupported testcase
if not FlashAttentionForwardAmpere.can_implement(
@ -1207,6 +1193,23 @@ def run_flash_attention_fwd(
f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}"
)
print(f"Running Ampere SM80 FlashAttentionForward test with:")
print(f" dtype: {dtype}")
print(f" batch_size: {batch_size}")
print(f" seqlen_q: {seqlen_q}")
print(f" seqlen_k: {seqlen_k}")
print(f" num_head: {num_head}")
print(f" head_dim: {head_dim}")
print(f" softmax_scale: {softmax_scale}")
print(f" m_block_size: {m_block_size}")
print(f" n_block_size: {n_block_size}")
print(f" num_threads: {num_threads}")
print(f" is_causal: {is_causal}")
print(f" warmup_iterations: {warmup_iterations}")
print(f" iterations: {iterations}")
print(f" skip_ref_check: {skip_ref_check}")
print(f" use_cold_l2: {use_cold_l2}")
# Create tensor Q/K/V/O
def create_tensor(
batch_size: int,
@ -1217,22 +1220,28 @@ def run_flash_attention_fwd(
) -> tuple[cute.Tensor, torch.Tensor]:
# (batch_size, seqlen, num_head, head_dim)
shape = (batch_size, seqlen, num_head, head_dim)
return (
torch.empty(*shape, dtype=torch.int32).random_(-2, 2).to(dtype=dtype).cuda()
torch_tensor = (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
.to(dtype=cutlass_torch.dtype(dtype))
.cuda()
)
# assume input is 16B aligned.
cute_tensor = (
from_dlpack(torch_tensor, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3,
stride_order=torch_tensor.dim_order(),
divisibility=(128 // dtype.width),
)
)
return cute_tensor, torch_tensor
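# Note: with assumed_align=16, divisibility = 128 // dtype.width marks the dynamic mode as a
# multiple of the elements per 16 bytes (e.g. 8 for a 16-bit dtype, since 8 * 2 B = 16 B),
# keeping the 16-byte alignment assumption above valid for every row.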
q = create_tensor(
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
)
k = create_tensor(
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
)
v = create_tensor(
batch_size, seqlen_k, num_head, head_dim, cutlass_torch.dtype(dtype)
)
o = create_tensor(
batch_size, seqlen_q, num_head, head_dim, cutlass_torch.dtype(dtype)
)
q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
fa2_fwd = FlashAttentionForwardAmpere(
head_dim,
@ -1241,78 +1250,63 @@ def run_flash_attention_fwd(
num_threads,
is_causal,
)
# assume input is 16B align.
q_tensor = (
from_dlpack(q, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=q.dim_order(), divisibility=(128 // dtype.width)
)
)
k_tensor = (
from_dlpack(k, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=k.dim_order(), divisibility=(128 // dtype.width)
)
)
v_tensor = (
from_dlpack(v, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=v.dim_order(), divisibility=(128 // dtype.width)
)
)
o_tensor = (
from_dlpack(o, assumed_align=16)
.mark_layout_dynamic(leading_dim=3)
.mark_compact_shape_dynamic(
mode=3, stride_order=o.dim_order(), divisibility=(128 // dtype.width)
)
)
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# compile the fa2 forward pass
compiled_fa2_fwd = cute.compile(
fa2_fwd, q_tensor, k_tensor, v_tensor, o_tensor, softmax_scale, current_stream
compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream)
if not skip_ref_check:
compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream)
torch.cuda.synchronize()
q_ref = q_torch.permute(0, 2, 1, 3)
k_ref = k_torch.permute(0, 2, 1, 3)
v_ref = v_torch.permute(0, 2, 1, 3)
torch.backends.cuda.enable_flash_sdp(enabled=True)
ref_o = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
).permute(0, 2, 1, 3)
torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
print("Results verified successfully!")
def generate_tensors():
q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype)
o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype)
return testing.JitArguments(
q_workspace,
k_workspace,
v_workspace,
o_workspace,
softmax_scale,
current_stream,
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
q_torch.numel() * q_torch.element_size()
+ k_torch.numel() * k_torch.element_size()
+ v_torch.numel() * v_torch.element_size()
+ o_torch.numel() * o_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
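# Rough intuition (assumed numbers, not the library's actual heuristic): if one Q/K/V/O
# workspace set totals ~16 MB and the GPU has a 40 MB L2, rotating through at least
# ceil(40 / 16) + 1 = 4 workspaces keeps each iteration's inputs out of L2;
# testing.get_workspace_count encapsulates the real sizing logic.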
avg_time_us = testing.benchmark(
compiled_fa2_fwd,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
# warmup
for _ in range(warmup_iterations):
compiled_fa2_fwd(
q_tensor,
k_tensor,
v_tensor,
o_tensor,
softmax_scale,
current_stream,
)
# run the compiled fa2 forward pass
for _ in range(iterations):
compiled_fa2_fwd(
q_tensor,
k_tensor,
v_tensor,
o_tensor,
softmax_scale,
current_stream,
)
torch.cuda.synchronize()
if skip_ref_check:
return
# reference implementation
q_ref = q.permute(0, 2, 1, 3)
k_ref = k.permute(0, 2, 1, 3)
v_ref = v.permute(0, 2, 1, 3)
torch.backends.cuda.enable_flash_sdp(enabled=True)
ref_o = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal
).permute(0, 2, 1, 3)
torch.testing.assert_close(o.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04)
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@ -1334,9 +1328,15 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference check"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
run_flash_attention_fwd(
run(
args.dtype,
args.batch_size,
args.seqlen_q,
@ -1348,6 +1348,10 @@ if __name__ == "__main__":
args.n_block_size,
args.num_threads,
args.is_causal,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -634,16 +634,50 @@ class SGemm:
return
def main(
def run(
mnk: Tuple[int, int, int],
a_major: str,
b_major: str,
c_major: str,
problem_shape: Tuple[int, int, int],
static_shape: bool = False,
warmup_iterations: int = 2,
iterations: int = 100,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
M, N, K = problem_shape
"""Execute SIMT GEMM operation and benchmark performance.
:param mnk: GEMM problem size (M, N, K)
:type mnk: Tuple[int, int, int]
:param a_major: Memory layout of tensor A
:type a_major: str
:param b_major: Memory layout of tensor B
:type b_major: str
:param c_major: Memory layout of tensor C
:type c_major: str
:param static_shape: Whether to use static shape optimization, defaults to False
:type static_shape: bool, optional
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 2
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 100
:type iterations: int, optional
:param skip_ref_check: Skip validation against reference implementation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Ampere SIMT GEMM example:")
print(f"mnk: {mnk}")
print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}")
print(f"Static shape: {static_shape}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
M, N, K = mnk
# Create and permute tensor A/B/C
def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
@ -710,20 +744,6 @@ def main(
print("Executing GEMM kernel...")
avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(
a_tensor, b_tensor, c_tensor, current_stream
),
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
use_cuda_graphs=False,
stream=current_stream,
)
# Print execution results
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
torch.cuda.synchronize()
@ -732,6 +752,71 @@ def main(
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
# Create new tensors for each workspace to ensure cold L2 cache
a_workspace = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
b_workspace = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
c_workspace = create_and_permute_tensor(M, N, c_major == "m", torch.float32)
if static_shape:
a_tensor_workspace = (
from_dlpack(a_workspace, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if a_major == "k" else 0),
divisibility=divisibility_a,
)
)
else:
a_tensor_workspace = from_dlpack(a_workspace, assumed_align=16)
b_tensor_workspace = (
from_dlpack(b_workspace, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if b_major == "k" else 0),
divisibility=divisibility_b,
)
)
c_tensor_workspace = (
from_dlpack(c_workspace, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
.mark_compact_shape_dynamic(
mode=(1 if c_major == "n" else 0),
divisibility=divisibility_c,
)
)
return testing.JitArguments(
a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, current_stream
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a.numel() * a.element_size()
+ b.numel() * b.element_size()
+ c.numel() * c.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
avg_time_us = testing.benchmark(
gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
# Print execution results
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
@ -753,19 +838,27 @@ if __name__ == "__main__":
parser.add_argument("--warmup_iterations", default=2, type=int)
parser.add_argument("--iterations", default=100, type=int)
parser.add_argument("--skip_ref_check", action="store_true")
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
print("Running SIMT GEMM example:")
torch.manual_seed(1024)
main(
run(
args.mnk,
args.a_major,
args.b_major,
args.c_major,
args.mnk,
args.static_shape,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -51,7 +51,7 @@ This GEMM kernel supports the following features:
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
- Threadblock rasterization to improve data re-use
- Supports multi-stage pipeline to overlap computation and memory access
- Implements shared memory buffering for epilogue to increase coalesed global memory access
- Implements shared memory buffering for epilogue to increase coalesced global memory access
This GEMM works as follows:
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
@ -214,7 +214,7 @@ class TensorOpGemm:
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
)
# Creates a synchonous copy atom and thread layouts for the epilogue
# Creates a synchronous copy atom and thread layouts for the epilogue
c_copy_bits = 128
atom_sync_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
@ -550,16 +550,8 @@ class TensorOpGemm:
# Creates the tiled copy so that it matches the thread-value layout
# expected by the tiled mma
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=tiled_mma.tv_layout_A_tiled,
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=tiled_mma.tv_layout_B_tiled,
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
)
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
@ -836,8 +828,7 @@ class TensorOpGemm:
if major_mode == utils.LayoutEnum.ROW_MAJOR
else cute.make_layout((copy_elems, 1))
)
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
def raster_tile(self, i, j, f):
new_i = i // f
@ -845,20 +836,33 @@ class TensorOpGemm:
return (new_i, new_j)
def run_tensor_op_gemm(
def run(
a_major: str,
b_major: str,
c_major: str,
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
problem_shape: Tuple[int, int, int, int],
mnkl: Tuple[int, int, int, int],
atom_layout_mnk: Tuple[int, int, int],
warmup_iterations: int = 2,
iterations: int = 100,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
M, N, K, L = problem_shape
print(f"Running Ampere tensor core GEMM example:")
print(f"mnkl: {mnkl}")
print(
f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
)
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Atoms layout: {atom_layout_mnk}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
M, N, K, L = mnkl
# Create and permute tensor A/B/C
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
@ -866,23 +870,28 @@ def run_tensor_op_gemm(
# else: (l, mode0, mode1) -> (mode0, mode1, l)
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
return (
torch_tensor = (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
.to(dtype=dtype)
.to(dtype=cutlass_torch.dtype(dtype))
.permute(permute_order)
.cuda()
)
# assume input is 16B aligned
cute_tensor = (
from_dlpack(torch_tensor, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0))
.mark_compact_shape_dynamic(
mode=(1 if not is_mode0_major else 0),
stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0),
divisibility=(128 // dtype.width),
)
)
return cute_tensor, torch_tensor
a = create_and_permute_tensor(
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
)
b = create_and_permute_tensor(
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
)
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
tensor_op_gemm = TensorOpGemm(
ab_dtype,
@ -891,56 +900,49 @@ def run_tensor_op_gemm(
atom_layout_mnk,
)
# assume input is 16B aligned
a_tensor = (
from_dlpack(a, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if a_major == "k" else 0),
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
divisibility=(128 // ab_dtype.width),
)
)
b_tensor = (
from_dlpack(b, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if b_major == "k" else 0),
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
divisibility=(128 // ab_dtype.width),
)
)
c_tensor = (
from_dlpack(c, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
.mark_compact_shape_dynamic(
mode=(1 if c_major == "n" else 0),
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
divisibility=(128 // c_dtype.width),
)
)
print("Compiling kernel with cute.compile ...")
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC)
print("Executing GEMM kernel...")
if not skip_ref_check:
ref = torch.einsum(
"mkl,nkl->mnl",
a_torch.to(dtype=torch.float32),
b_torch.to(dtype=torch.float32),
).to(cutlass_torch.dtype(c_dtype))
compiled_gemm(mA, mB, mC)
print("Verifying results...")
torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype)
b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype)
c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype)
return testing.JitArguments(a_workspace, b_workspace, c_workspace)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch.numel() * a_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
avg_time_us = testing.benchmark(
gemm,
kernel_arguments=testing.JitArguments(a_tensor, b_tensor, c_tensor),
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
warmup_iterations=warmup_iterations,
profiling_iterations=iterations,
iterations=iterations,
use_cuda_graphs=False,
)
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
if not skip_ref_check:
gemm(a_tensor, b_tensor, c_tensor)
print("Verifying results...")
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
print("Results verified successfully!")
return avg_time_us # Return execution time in microseconds
if __name__ == "__main__":
@ -985,10 +987,15 @@ if __name__ == "__main__":
parser.add_argument("--warmup_iterations", default=2, type=int)
parser.add_argument("--iterations", default=100, type=int)
parser.add_argument("--skip_ref_check", action="store_true")
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
print("Running Ampere tensor core GEMM example:")
run_tensor_op_gemm(
run(
args.a_major,
args.b_major,
args.c_major,
@ -1000,5 +1007,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

File diff suppressed because it is too large

View File

@ -212,7 +212,7 @@ class DenseGemmKernel:
self.occupancy = 1
self.threads_per_cta = 128
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -1106,11 +1106,7 @@ class DenseGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
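# Note: make_tiled_copy_D builds the register-to-shared store so its destination partitioning
# matches tiled_copy_t2r, replacing the explicit layout_dst_tv_tiled / tiler_mn arguments used previously.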
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
@ -1772,7 +1768,7 @@ def run_dense_gemm(
ref_c = ref
elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
# m major: (l, n, m) -> (m, n, l)
# k major: (l, m, n) -> (m, n, l)
# n major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(

View File

@ -38,6 +38,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.cute.testing as testing
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
@ -226,7 +227,7 @@ class PersistentDenseGemmKernel:
self.cta_sync_bar_id = 0
self.epilog_sync_bar_id = 1
self.tmem_ptr_sync_bar_id = 2
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -1308,11 +1309,7 @@ class PersistentDenseGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
@ -1824,7 +1821,7 @@ class PersistentDenseGemmKernel:
return can_implement
def run_dense_gemm(
def run(
mnkl: Tuple[int, int, int, int],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
@ -1832,17 +1829,58 @@ def run_dense_gemm(
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
use_2cta_instrs: bool,
use_tma_store: bool,
tolerance: float,
mma_tiler_mn: Tuple[int, int] = (256, 256),
cluster_shape_mn: Tuple[int, int] = (2, 1),
use_2cta_instrs: bool = True,
use_tma_store: bool = True,
tolerance: float = 1e-01,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
"""
Prepare A/B/C tensors, launch GPU kernel, and reference checking.
"""Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
This function prepares input tensors, configures and launches the persistent GEMM kernel,
optionally performs reference validation, and benchmarks the execution performance.
:param mnkl: Problem size (M, N, K, L)
:type mnkl: Tuple[int, int, int, int]
:param ab_dtype: Data type for input tensors A and B
:type ab_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Data type for accumulation during matrix multiplication
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param mma_tiler_mn: MMA tiling size; the autotuner falls back to (256, 256) unless a value is given in the decorator parameters
:type mma_tiler_mn: Tuple[int, int], optional
:param cluster_shape_mn: Cluster shape; the autotuner falls back to (2, 1) unless a value is given in the decorator parameters
:type cluster_shape_mn: Tuple[int, int], optional
:param use_2cta_instrs: Whether to use 2CTA instructions; the autotuner falls back to True unless a value is given in the decorator parameters
:type use_2cta_instrs: bool, optional
:param use_tma_store: Whether to use TMA store; the autotuner falls back to True unless a value is given in the decorator parameters
:type use_tma_store: bool, optional
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
:type tolerance: float, optional
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:raises RuntimeError: If CUDA GPU is not available
:raises ValueError: If the configuration is invalid or unsupported by the kernel
:return: Execution time of the GEMM kernel
:rtype: float
"""
print(f"Running Blackwell Persistent Dense GEMM test with:")
print(f"mnkl: {mnkl}")
@ -1855,6 +1893,7 @@ def run_dense_gemm(
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
# Unpack parameters
m, n, k, l = mnkl
@ -1931,15 +1970,15 @@ def run_dense_gemm(
is_dynamic_layout=is_dynamic_layout,
)
return f32_torch_tensor, cute_tensor, torch_tensor
return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu
a_ref, a_tensor, a_torch = create_and_permute_tensor(
a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor(
l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True
)
b_ref, b_tensor, b_torch = create_and_permute_tensor(
b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor(
l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True
)
c_ref, c_tensor, c_torch = create_and_permute_tensor(
c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor(
l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True
)
@ -1967,16 +2006,8 @@ def run_dense_gemm(
gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream
)
# Launch GPU kernel
# Warm up
for i in range(warmup_iterations):
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
# Execution
for i in range(iterations):
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
# Compute reference result
if not skip_ref_check:
compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream)
if ab_dtype in {
cutlass.Int8,
cutlass.Uint8,
@ -2028,6 +2059,40 @@ def run_dense_gemm(
rtol=1e-05,
)
def generate_tensors():
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
c_tensor, _ = cutlass_torch.cute_tensor_like(
c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16
)
return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch_cpu.numel() * a_torch_cpu.element_size()
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
+ c_torch_cpu.numel() * c_torch_cpu.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
@ -2090,6 +2155,12 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -2102,7 +2173,7 @@ if __name__ == "__main__":
if len(args.cluster_shape_mn) != 2:
parser.error("--cluster_shape_mn must contain exactly 2 values")
run_dense_gemm(
run(
args.mnkl,
args.ab_dtype,
args.c_dtype,
@ -2118,5 +2189,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -223,7 +223,7 @@ class DenseGemmKernel:
self.occupancy = 1
self.threads_per_cta = 128
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
@ -1063,11 +1063,7 @@ class DenseGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)

View File

@ -43,6 +43,7 @@ import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import Int32, Int64, Float32, Boolean
@ -90,7 +91,7 @@ Constraints for this example:
* Number of heads in Q must be divisible by number of heads in K
* mma_tiler_mn must be 128,128
* Batch size must be the same for Q, K, and V tensors
* For causal masking, use --has_casual_mask (note: specify without =True/False)
* For causal masking, use --is_causal (note: specify without =True/False)
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
"""
@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward:
smem_copy_atom = sm100_utils.get_smem_store_op(
self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
)
tiled_smem_store = cute.make_tiled_copy(
smem_copy_atom,
layout_tv=tiled_tmem_load.layout_dst_tv_tiled,
tiler_mn=tiled_tmem_load.tiler_mn,
)
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward:
return tile_sched_params, grid
def run_fmha_and_verify(
def run(
q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
in_dtype: Type[cutlass.Numeric],
@ -2628,7 +2625,7 @@ def run_fmha_and_verify(
pv_acc_dtype: Type[cutlass.Numeric],
mma_tiler_mn: Tuple[int, int],
is_persistent: bool,
has_casual_mask: bool,
is_causal: bool,
scale_q: float,
scale_k: float,
scale_v: float,
@ -2638,6 +2635,8 @@ def run_fmha_and_verify(
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results.
@ -2670,8 +2669,8 @@ def run_fmha_and_verify(
:type mma_tiler_mn: Tuple[int, int]
:param is_persistent: Whether to use persistent kernel optimization
:type is_persistent: bool
:param has_casual_mask: Whether to apply causal masking
:type has_casual_mask: bool
:param is_causal: Whether to apply causal masking
:type is_causal: bool
:param scale_q: Scaling factor for query tensor
:type scale_q: float
:param scale_k: Scaling factor for key tensor
@ -2690,9 +2689,13 @@ def run_fmha_and_verify(
:type iterations: int
:param skip_ref_check: Skip validation against reference implementation
:type skip_ref_check: bool
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
:type use_cold_l2: bool
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
:raises RuntimeError: If GPU is unavailable for computation
:return: Execution time of the FMHA kernel in microseconds
:rtype: float
"""
print(f"Running Blackwell SM100 FMHA test with:")
@ -2704,13 +2707,17 @@ def run_fmha_and_verify(
print(f" pv_acc_dtype: {pv_acc_dtype}")
print(f" mma_tiler_mn: {mma_tiler_mn}")
print(f" is_persistent: {is_persistent}")
print(f" has_casual_mask: {has_casual_mask}")
print(f" is_causal: {is_causal}")
print(f" scale_q: {scale_q}")
print(f" scale_k: {scale_k}")
print(f" scale_v: {scale_v}")
print(f" inv_scale_o: {inv_scale_o}")
print(f" scale_softmax: {scale_softmax}")
print(f" tolerance: {tolerance}")
print(f" warmup_iterations: {warmup_iterations}")
print(f" iterations: {iterations}")
print(f" skip_ref_check: {skip_ref_check}")
print(f" use_cold_l2: {use_cold_l2}")
# Unpack parameters
b, s_q, h_q, d = q_shape
@ -2882,7 +2889,7 @@ def run_fmha_and_verify(
mma_tiler = (*mma_tiler_mn, d)
mask_type = MaskType.NO_MASK
if has_casual_mask:
if is_causal:
mask_type = MaskType.CAUSAL_MASK
else:
if isinstance(s_k, tuple):
@ -2942,41 +2949,7 @@ def run_fmha_and_verify(
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
# Warmup
for _ in range(warmup_iterations):
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
# Execute kernel
for _ in range(iterations):
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
torch.cuda.synchronize()
def run_torch_fmha(
q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False
):
def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False):
h_q = q.shape[2]
h_k = k.shape[2]
@ -3005,7 +2978,7 @@ def run_fmha_and_verify(
v = v.transpose(1, 2)
# For the situation that torch has not supported, we need to handle it manually
situation1 = has_casual_mask and (q.is_nested or k.is_nested)
situation1 = is_causal and (q.is_nested or k.is_nested)
situation2 = (q.is_nested and not k.is_nested) or (
not q.is_nested and k.is_nested
)
@ -3025,8 +2998,9 @@ def run_fmha_and_verify(
attn_mask=None,
dropout_p=0.0,
scale=scale_softmax,
is_causal=has_casual_mask,
is_causal=is_causal,
)
ref_i = ref_i.transpose(0, 1) * scale_output
ref_list.append(ref_i)
if q.is_nested:
ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged)
@ -3040,15 +3014,28 @@ def run_fmha_and_verify(
attn_mask=None,
dropout_p=0.0,
scale=scale_softmax,
is_causal=has_casual_mask,
is_causal=is_causal,
)
ref = ref.transpose(1, 2) * scale_output
ref = ref.transpose(1, 2) * scale_output
return ref
if not skip_ref_check:
# Execute kernel once for reference checking
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
print("Verifying results...")
o_ref = run_torch_fmha(
q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask
q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal
)
if o_ref.is_nested:
@ -3095,6 +3082,76 @@ def run_fmha_and_verify(
torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
_, q_tensor_workspace, _ = create_and_pad_tensor(
qo_shape,
qo_padding,
in_dtype,
s_cumsum=cum_seqlen_q_torch,
is_dynamic_layout=True,
)
_, k_tensor_workspace, _ = create_and_pad_tensor(
kv_shape,
kv_padding,
in_dtype,
s_cumsum=cum_seqlen_k_torch,
is_dynamic_layout=True,
)
_, v_tensor_workspace, _ = create_and_pad_tensor(
kv_shape,
kv_padding,
in_dtype,
s_cumsum=cum_seqlen_k_torch,
is_dynamic_layout=True,
)
_, o_tensor_workspace, _ = create_and_pad_tensor(
qo_shape,
qo_padding,
out_dtype,
s_cumsum=cum_seqlen_q_torch,
is_dynamic_layout=True,
)
return testing.JitArguments(
q_tensor_workspace.iterator,
k_tensor_workspace.iterator,
v_tensor_workspace.iterator,
o_tensor_workspace.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
workspace_count = 1
if use_cold_l2:
q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
one_workspace_bytes = (
q_torch_effective.numel() * q_torch_effective.element_size()
+ k_torch_effective.numel() * k_torch_effective.element_size()
+ v_torch_effective.numel() * v_torch_effective.element_size()
+ o_torch_effective.numel() * o_torch_effective.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_fmha,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
def parse_comma_separated_ints(s: str):
@ -3185,7 +3242,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--has_casual_mask",
"--is_causal",
action="store_true",
help="Whether to use casual mask",
)
@ -3263,6 +3320,13 @@ if __name__ == "__main__":
help="Skip reference check",
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
if len(args.q_shape) != 4:
@ -3279,7 +3343,7 @@ if __name__ == "__main__":
torch.manual_seed(1111)
run_fmha_and_verify(
run(
args.q_shape,
args.k_shape,
args.in_dtype,
@ -3288,7 +3352,7 @@ if __name__ == "__main__":
args.pv_acc_dtype,
args.mma_tiler_mn,
args.is_persistent,
args.has_casual_mask,
args.is_causal,
args.scale_q,
args.scale_k,
args.scale_v,
@ -3298,6 +3362,7 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -36,6 +36,7 @@ import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
@ -157,7 +158,7 @@ class GroupedGemmKernel:
self.tmem_ptr_sync_bar_id = 2
# Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion
self.tensormap_ab_init_bar_id = 4
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
self.num_tma_load_bytes = 0
def _setup_attributes(self):
@ -951,7 +952,7 @@ class GroupedGemmKernel:
# Specialized MMA warp
#
if warp_idx == self.mma_warp_id:
# initilize tensormap A, B for TMA warp
# initialize tensormap A, B for TMA warp
if cutlass.const_expr(self.delegate_tensormap_ab_init):
tensormap_manager.init_tensormap_from_atom(
tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id
@ -1540,11 +1541,7 @@ class GroupedGemmKernel:
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy(
copy_atom_r2s,
layout_tv=tiled_copy_t2r.layout_dst_tv_tiled,
tiler_mn=tiled_copy_t2r.tiler_mn,
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
@ -1815,7 +1812,136 @@ class GroupedGemmKernel:
tensor_memory_management_bytes = 12
def run_grouped_gemm(
# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
torch_tensor_cpu: torch.Tensor = None,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
"""Create a GPU tensor from scratch or based on an existing CPU tensor.
:param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one.
:type torch_tensor_cpu: torch.Tensor, optional
"""
if torch_tensor_cpu is None:
# Create new CPU tensor
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
# Create GPU tensor from CPU tensor (new or existing)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)
def create_tensors_for_all_groups(
problem_sizes_mnkl: List[tuple[int, int, int, int]],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
torch_fp32_tensors_abc: List[List[torch.Tensor]] = None,
) -> tuple[
List[List[int]],
List[List[torch.Tensor]],
List[tuple],
List[List[tuple]],
List[List[torch.Tensor]],
]:
if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len(
problem_sizes_mnkl
):
raise ValueError("torch_fp32_tensors_abc must have one entry per group")
# Initialize lists to store tensors for all groups
new_torch_fp32_tensors_abc = (
[] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc
)
torch_tensors_abc = []
cute_tensors_abc = []
strides_abc = []
ptrs_abc = []
# Iterate through all groups and create tensors for each group
for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
# Get existing CPU tensors if available, otherwise None
existing_cpu_a = (
torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None
)
existing_cpu_b = (
torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None
)
existing_cpu_c = (
torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None
)
# Create tensors (reusing CPU tensors if provided)
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(
l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a
)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(
l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b
)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(
l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c
)
# Only append to new_torch_fp32_tensors_abc if we created new CPU tensors
if torch_fp32_tensors_abc is None:
new_torch_fp32_tensors_abc.append(
[tensor_fp32_a, tensor_fp32_b, tensor_fp32_c]
)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
cute_tensors_abc.append(
(
cute_tensor_a,
cute_tensor_b,
cute_tensor_c,
)
)
return (
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
new_torch_fp32_tensors_abc,
)
def run(
num_groups: int,
problem_sizes_mnkl: tuple[int, int, int, int],
ab_dtype: Type[cutlass.Numeric],
@ -1832,8 +1958,16 @@ def run_grouped_gemm(
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""Run grouped GEMM example with specified configurations."""
"""Run grouped GEMM example with specified configurations.
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Blackwell Grouped GEMM test with:")
print(f"{num_groups} groups")
for i, (m, n, k, l) in enumerate(problem_sizes_mnkl):
@ -1847,6 +1981,7 @@ def run_grouped_gemm(
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
# Skip unsupported types
if ab_dtype not in {
@ -1902,66 +2037,22 @@ def run_grouped_gemm(
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run this example!")
# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
torch_tensor_cpu,
torch_tensor.stride()[:-1],
)
# Create tensors for all groups using the new function
(
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
torch_fp32_tensors_abc,
) = create_tensors_for_all_groups(
problem_sizes_mnkl,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
)
# iterate all groups and create tensors for each group
torch_fp32_tensors_abc = []
torch_tensors_abc = []
cute_tensors_abc = []
strides_abc = []
ptrs_abc = []
for _, (m, n, k, l) in enumerate(problem_sizes_mnkl):
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
tensor_fp32_a,
stride_mk_a,
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
tensor_fp32_b,
stride_nk_b,
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
tensor_fp32_c,
stride_mn_c,
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
torch_fp32_tensors_abc.append([tensor_fp32_a, tensor_fp32_b, tensor_fp32_c])
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
cute_tensors_abc.append(
(
cute_tensor_a,
cute_tensor_b,
cute_tensor_c,
)
)
# Choose A, B, C with the smallest size to create initial tensormaps
key_size_a = lambda item: item[1][0] * item[1][2]
key_size_b = lambda item: item[1][1] * item[1][2]
@ -2078,36 +2169,19 @@ def run_grouped_gemm(
current_stream,
)
# Launch GPU kernel
# Warm up
for _ in range(warmup_iterations):
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
# Execution
for i in range(iterations):
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
torch.cuda.synchronize()
# Compute reference result
if not skip_ref_check:
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
# Compute reference result
for i, (a, b, c) in enumerate(torch_tensors_abc):
ref = torch.einsum(
"mkl,nkl->mnl",
@ -2122,6 +2196,102 @@ def run_grouped_gemm(
rtol=1e-05,
)
def generate_tensors():
# Reuse existing CPU tensors and create new GPU tensors from them
(
ptrs_abc_workspace,
torch_tensors_abc_workspace,
cute_tensors_abc_workspace,
strides_abc_workspace,
_,
) = create_tensors_for_all_groups(
problem_sizes_mnkl,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
torch_fp32_tensors_abc,
)
initial_cute_tensors_abc_workspace = [
cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k)
cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k)
cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n)
]
# Create new tensors for this workspace
tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc_workspace, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc_workspace, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)
tensormap_workspace, _ = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
)
return testing.JitArguments(
initial_cute_tensors_abc_workspace[0],
initial_cute_tensors_abc_workspace[1],
initial_cute_tensors_abc_workspace[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc_workspace,
tensor_of_ptrs_abc_workspace,
tensormap_workspace,
current_stream,
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
sum(
[
sum(
[
torch_tensor.numel() * torch_tensor.element_size()
for torch_tensor in group_tensors
]
)
for group_tensors in torch_tensors_abc
]
)
+
# Add size of strides tensor
tensor_of_strides_abc_torch.numel()
* tensor_of_strides_abc_torch.element_size()
+
# Add size of ptrs tensor
tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size()
+
# Add size of tensormap tensor
tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_grouped_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
@ -2218,6 +2388,12 @@ if __name__ == "__main__":
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -2248,7 +2424,7 @@ if __name__ == "__main__":
torch.manual_seed(2025)
run_grouped_gemm(
run(
args.num_groups,
args.problem_sizes_mnkl,
args.ab_dtype,
@ -2265,5 +2441,6 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -29,13 +29,14 @@
import argparse
from typing import List, Type, Tuple, Optional
from cuda import cuda
import cuda.bindings.driver as cuda
import torch
import torch.nn.functional as F
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
@ -43,13 +44,16 @@ import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
from .mamba2_ssd_reference import (
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent))
from mamba2_ssd_reference import (
ssd_reference_fp32_all,
ssd_reference_lowprecision_intermediates,
analyze_relative_diffs,
)
from .mamba2_ssd_tile_scheduler import (
from mamba2_ssd_tile_scheduler import (
Mamba2SSDTileSchedulerParams,
Mamba2SSDTileScheduler,
)
@ -122,7 +126,7 @@ class SSDKernel:
*self.epilog_warp_id,
)
)
self.smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
# Named barriers
self.pre_inter_sync_bar_id = 1
@ -1522,7 +1526,10 @@ class SSDKernel:
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N)
# ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE)
tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b(
local_tidx, smem_bt_internal_, tiled_s2r_b, tBrB_s2r
local_tidx,
smem_bt_internal_,
tiled_s2r_b,
tBrB_s2r,
)
# (MMA, MMA_M, MMA_K, INPUT_STAGE)
@ -3053,7 +3060,7 @@ class SSDKernel:
# SegSum
# fadd2 + fsel + fmul2/mufu + fmul2
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
(
tCompute[subtile_idx],
tCompute[subtile_idx + 1],
@ -3061,11 +3068,11 @@ class SSDKernel:
(tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]),
(-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]),
)
for subtile_idx in range(cute.size(tTR_rQ)):
for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True):
m, n = tCoord[subtile_idx]
if m < n:
tCompute[subtile_idx] = cutlass.Float32(-float("inf"))
for subtile_idx in range(0, cute.size(tTR_rQ), 2):
for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True):
# TODO: use math.exp directly
(
tCompute[subtile_idx],
@ -3130,11 +3137,7 @@ class SSDKernel:
dtype,
num_bits_per_copy=128,
)
tiled_r2s_b = cute.make_tiled_copy(
copy_atom_r2s_b,
layout_tv=tiled_s2r_b.layout_tv_tiled,
tiler_mn=tiled_s2r_b.tiler_mn,
)
tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b)
thr_r2s_b = tiled_r2s_b.get_slice(local_tidx)
# Partition shared tensor for smem store Bt
@ -3333,17 +3336,24 @@ class SSDKernel:
)
def run_ssd(
def run(
gbehcdln: Tuple[int, int, int, int, int, int, int, int],
io_dtype: Type[cutlass.Numeric],
cumsum_delta_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
has_d: bool,
d_has_hdim: bool,
fuse_scale_d: str,
tolerance: float,
print_rtol_stats: bool,
ref_lower_precision: bool,
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
has_d = fuse_scale_d != "none"
d_has_hdim = fuse_scale_d == "vector"
print(f"Running B100 Mamba2 SSD with:")
print(f"GBEHCDLN: {gbehcdln}")
print(
@ -3353,6 +3363,10 @@ def run_ssd(
f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}"
)
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
# Unpack parameters
G, B, E, H, C, D, L, N = gbehcdln
@ -3515,39 +3529,146 @@ def run_ssd(
stream,
)
# Launch compiled ssd kernel
compiled_ssd(
x_tensor,
cumsum_delta_tensor,
delta_tensor,
b_tensor,
c_tensor,
y_tensor,
fstate_tensor,
d_tensor,
stream,
# Launch compiled ssd kernel for reference check
if not skip_ref_check:
compiled_ssd(
x_tensor,
cumsum_delta_tensor,
delta_tensor,
b_tensor,
c_tensor,
y_tensor,
fstate_tensor,
d_tensor,
stream,
)
# Reference check
if print_rtol_stats:
print("\nY's Relative diffs:")
analyze_relative_diffs(
y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype))
)
print("\nFstate's Relative diffs:")
analyze_relative_diffs(
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
)
torch.testing.assert_close(
y_torch.cpu(),
y_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-02,
)
torch.testing.assert_close(
fstate_torch.cpu(),
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-05,
)
def generate_tensors():
# Reuse existing CPU reference tensors and create new GPU tensors from them
_, x_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, C, L],
[2, 4, 3, 1, 0],
io_dtype,
ref_tensor=x_ref,
dynamic_modes=[2, 3, 4],
)
_, cumsum_delta_tensor_new, _ = create_and_permute_tensor(
[B, EH, C, L],
[3, 2, 1, 0],
cumsum_delta_dtype,
ref_tensor=cumsum_delta_ref,
dynamic_modes=[1, 2, 3],
)
_, delta_tensor_new, _ = create_and_permute_tensor(
[B, EH, C, L],
[3, 2, 1, 0],
io_dtype,
ref_tensor=delta_ref,
dynamic_modes=[1, 2, 3],
)
_, b_tensor_new, _ = create_and_permute_tensor(
[B, G, N, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=b_ref,
dynamic_modes=[2, 3, 4],
)
_, c_tensor_new, _ = create_and_permute_tensor(
[B, G, N, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=c_ref,
dynamic_modes=[2, 3, 4],
)
_, y_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, C, L],
[4, 2, 3, 1, 0],
io_dtype,
ref_tensor=y_ref,
dynamic_modes=[2, 3, 4],
)
_, fstate_tensor_new, _ = create_and_permute_tensor(
[B, EH, D, N],
[2, 3, 1, 0],
io_dtype,
ref_tensor=fstate_ref,
dynamic_modes=[2, 3],
)
if has_d:
_, d_tensor_new, _ = create_and_permute_tensor(
[EH, D if d_has_hdim else 1],
[1, 0],
io_dtype,
ref_tensor=d_ref,
dynamic_modes=[1],
)
else:
d_tensor_new = d_tensor
return testing.JitArguments(
x_tensor_new,
cumsum_delta_tensor_new,
delta_tensor_new,
b_tensor_new,
c_tensor_new,
y_tensor_new,
fstate_tensor_new,
d_tensor_new,
stream,
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
x_torch.numel() * x_torch.element_size()
+ cumsum_delta_torch.numel() * cumsum_delta_torch.element_size()
+ delta_torch.numel() * delta_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
+ y_torch.numel() * y_torch.element_size()
+ fstate_torch.numel() * fstate_torch.element_size()
)
if has_d:
one_workspace_bytes += d_torch.numel() * d_torch.element_size()
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_ssd,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
# Reference check
if print_rtol_stats:
print("\nY's Relative diffs:")
analyze_relative_diffs(y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)))
print("\nFstate's Relative diffs:")
analyze_relative_diffs(
fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype))
)
torch.testing.assert_close(
y_torch.cpu(),
y_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-02,
)
torch.testing.assert_close(
fstate_torch.cpu(),
fstate_ref.to(cutlass_torch.dtype(io_dtype)),
atol=tolerance,
rtol=1e-05,
)
return exec_time # Return execution time in microseconds
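As in the grouped GEMM example, the `workspace_generator`/`workspace_count` pair lets each timed launch run on a different argument set. Conceptually, a harness built on this pattern might look like the sketch below; it is only an illustration of the cycling idea, not the `cutlass.cute.testing.benchmark` implementation, and a real GPU benchmark would time with CUDA events and synchronize the stream.

.. code-block:: python

    import time

    # Illustrative host-side sketch: rotate through pre-generated argument sets so
    # that consecutive timed calls are unlikely to hit data still resident in L2.
    def benchmark_with_workspaces(
        kernel, workspace_generator, workspace_count, warmup_iterations, iterations
    ):
        # Here a workspace is assumed to be a plain tuple of kernel arguments;
        # the real API wraps them in testing.JitArguments.
        workspaces = [workspace_generator() for _ in range(workspace_count)]
        for i in range(warmup_iterations):
            kernel(*workspaces[i % workspace_count])
        start = time.perf_counter()  # a GPU benchmark would use CUDA events here
        for i in range(iterations):
            kernel(*workspaces[(warmup_iterations + i) % workspace_count])
        return (time.perf_counter() - start) / max(iterations, 1) * 1e6  # microseconds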
if __name__ == "__main__":
@ -3586,15 +3707,53 @@ if __name__ == "__main__":
)
parser.add_argument(
"--ref_lower_precision",
type=bool,
action="store_true",
default=True,
help="Use lower precision for reference check",
)
parser.add_argument(
"--no-ref_lower_precision",
action="store_false",
dest="ref_lower_precision",
default=False,
help="Disable lower precision for reference check",
)
parser.add_argument(
"--tolerance", type=float, default=5e-02, help="Tolerance for validation"
)
parser.add_argument(
"--print_rtol_stats", type=bool, default=True, help="Print rtol stats"
"--print_rtol_stats",
action="store_true",
default=True,
help="Enable print rtol stats",
)
parser.add_argument(
"--no-print_rtol_stats",
action="store_false",
dest="print_rtol_stats",
default=False,
help="Disable print rtol stats",
)
parser.add_argument(
"--warmup_iterations",
type=int,
default=0,
help="Number of warmup iterations",
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
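The move away from `type=bool` here is more than cosmetic: `argparse` applies `bool()` to the raw string, so `--ref_lower_precision False` used to parse as `True`. The paired `store_true`/`store_false` actions sharing one `dest` avoid that pitfall, as the small standalone sketch below (with a hypothetical `--flag` option) shows.

.. code-block:: python

    import argparse

    # Pitfall: type=bool calls bool() on the string, and bool("False") is True.
    bad = argparse.ArgumentParser()
    bad.add_argument("--flag", type=bool, default=True)
    print(bad.parse_args(["--flag", "False"]).flag)  # True, surprisingly

    # Paired store_true/store_false actions sharing one dest behave as intended.
    good = argparse.ArgumentParser()
    good.add_argument("--flag", action="store_true", default=True)
    good.add_argument("--no-flag", action="store_false", dest="flag")
    print(good.parse_args([]).flag)             # True  (default)
    print(good.parse_args(["--no-flag"]).flag)  # False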
@ -3602,18 +3761,18 @@ if __name__ == "__main__":
if len(args.gbehcdln) != 8:
parser.error("--gbehcdln must contain exactly 8 values")
has_d = args.fuse_scale_d != "none"
d_has_hdim = args.fuse_scale_d == "vector"
run_ssd(
run(
args.gbehcdln,
args.io_dtype,
args.cumsum_delta_dtype,
args.acc_dtype,
has_d,
d_has_hdim,
args.fuse_scale_d,
args.tolerance,
args.print_rtol_stats,
args.ref_lower_precision,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -35,6 +35,7 @@ import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
@ -166,6 +167,24 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
@ -264,7 +283,7 @@ class HopperWgmmaGemmKernel:
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
self.num_threads_per_warp_group = 128
self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"]
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
self.ab_stage = None
self.epi_stage = None
@ -1309,7 +1328,7 @@ class HopperWgmmaGemmKernel:
}:
is_valid = False
# tested acc_dtype
if acc_dtype != cutlass.Float32:
if acc_dtype not in {cutlass.Float32, cutlass.Float16}:
is_valid = False
# tested c_dtype
if c_dtype not in {
@ -1335,7 +1354,7 @@ class HopperWgmmaGemmKernel:
return is_valid
def run_dense_gemm(
def run(
mnkl: Tuple[int, int, int, int],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
@ -1347,9 +1366,43 @@ def run_dense_gemm(
tile_shape_mnk: Tuple[int, int, int],
cluster_shape_mn: Tuple[int, int],
tolerance: float,
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""
    Prepare A/B/C tensors, launch the GPU kernel, and run reference checking.
:param mnkl: Problem size (M, N, K, L)
:type mnkl: Tuple[int, int, int, int]
:param a_dtype: Data type for input tensor A
:type a_dtype: Type[cutlass.Numeric]
:param b_dtype: Data type for input tensor B
:type b_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Data type for accumulation during matrix multiplication
:type acc_dtype: Type[cutlass.Numeric]
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
:type a_major/b_major/c_major: str
:param tile_shape_mnk: CTA tile shape (M, N, K)
:type tile_shape_mnk: Tuple[int, int, int]
:param cluster_shape_mn: Cluster shape (M, N)
:type cluster_shape_mn: Tuple[int, int]
:param tolerance: Tolerance value for reference validation comparison
:type tolerance: float
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
:type use_cold_l2: bool, optional
:return: Execution time of the GEMM kernel in microseconds
:rtype: float
"""
print(f"Running Hopper Dense GEMM with:")
@ -1360,6 +1413,10 @@ def run_dense_gemm(
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {use_cold_l2}")
# Unpack parameters
m, n, k, l = mnkl
@ -1437,46 +1494,76 @@ def run_dense_gemm(
stream = cuda.CUstream(torch_stream.cuda_stream)
# compile gemm kernel
compiled_gemm = cute.compile(gemm, mA, mB, mC, stream)
# execution
compiled_gemm(mA, mB, mC, stream)
torch.cuda.synchronize()
if not skip_ref_check:
# execution
compiled_gemm(mA, mB, mC, stream)
# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
torch.cuda.synchronize()
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# k major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
# Ref check
ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu()
if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2):
# m major: (l, n, m) -> (m, n, l)
# n major: (l, m, n) -> (m, n, l)
permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0)
shape = (l, m, n) if c_major == "n" else (l, n, m)
f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor(
shape,
torch.uint8,
permute_order=permute_order,
init_type=cutlass_torch.TensorInitType.SKIP,
).cuda()
# Create dtype cute tensor (gpu)
ref_c_tensor = from_dlpack(
f8_torch_tensor, assumed_align=16
).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
ref_c_tensor.element_type = c_dtype
ref_c_tensor = cutlass_torch.convert_cute_tensor(
ref,
ref_c_tensor,
c_dtype,
is_dynamic_layout=True,
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
def generate_tensors():
_, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype)
_, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype)
_, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype)
return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch.numel() * a_torch.element_size()
+ b_torch.numel() * b_torch.element_size()
+ c_torch.numel() * c_torch.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
ref_c = f8_torch_tensor.cpu()
else:
ref_c = ref.to(cutlass_torch.dtype(c_dtype))
torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03)
exec_time = testing.benchmark(
compiled_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
args = parse_arguments()
run_dense_gemm(
run(
args.mnkl,
args.a_dtype,
args.b_dtype,
@ -1488,5 +1575,9 @@ if __name__ == "__main__":
args.tile_shape_mnk,
args.cluster_shape_mn,
args.tolerance,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")

View File

@ -399,6 +399,70 @@
"\n",
"tensor_print_example3()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def print_tensor_gpu(src: cute.Tensor):\n",
" print(src)\n",
" cute.print_tensor(src)\n",
"\n",
"@cute.jit\n",
"def print_tensor_host(src: cute.Tensor):\n",
" print_tensor_gpu(src).launch(grid=(1,1,1), block=(1,1,1))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
" [[-0.690547, -0.274619, -1.659539, ],\n",
" [-1.843524, -1.648711, 1.163431, ],\n",
" [-0.716668, -1.900705, 0.592515, ],\n",
" [ 0.711333, -0.552422, 0.860237, ]])\n"
]
}
],
"source": [
"import torch\n",
"def tensor_print_example4():\n",
" a = torch.randn(4, 3, device=\"cuda\")\n",
" cutlass.cuda.initialize_cuda_context()\n",
" print_tensor_host(from_dlpack(a))\n",
"\n",
"tensor_print_example4()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future."
]
}
],
"metadata": {

View File

@ -256,16 +256,6 @@
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
"\n",
"@cute.kernel\n",
"def print_tensor_gpu(ptr: cute.Pointer):\n",
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
" tensor = cute.make_tensor(ptr, layout)\n",
"\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
"\n",
" if tidx == 0:\n",
" cute.print_tensor(tensor)\n",
"\n",
"\n",
"# Create a tensor with sequential data using torch\n",
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",

View File

@ -363,7 +363,7 @@
"| | \"few_channels\" | optimized for small `C` and requires `C % alignment_input == 0`|\n",
"| | \"fixed_channels\" | optimized for small `C` and requires `C == alignment_input` |\n",
"|Dgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimzed for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"| | \"optimized\" | Optimized for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n",
"|Wgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n",
"| | \"optimized\" | Optimized for and require `K % alignment_grad_output == 0`, and `C % alignment_input == 0`|\n",
"\n",